init: add source code from src.zip
This commit is contained in:
179
src/services/AgentSummary/agentSummary.ts
Normal file
179
src/services/AgentSummary/agentSummary.ts
Normal file
@@ -0,0 +1,179 @@
|
||||
/**
|
||||
* Periodic background summarization for coordinator mode sub-agents.
|
||||
*
|
||||
* Forks the sub-agent's conversation every ~30s using runForkedAgent()
|
||||
* to generate a 1-2 sentence progress summary. The summary is stored
|
||||
* on AgentProgress for UI display.
|
||||
*
|
||||
* Cache sharing: uses the same CacheSafeParams as the parent agent
|
||||
* to share the prompt cache. Tools are kept in the request for cache
|
||||
* key matching but denied via canUseTool callback.
|
||||
*/
|
||||
|
||||
import type { TaskContext } from '../../Task.js'
|
||||
import { updateAgentSummary } from '../../tasks/LocalAgentTask/LocalAgentTask.js'
|
||||
import { filterIncompleteToolCalls } from '../../tools/AgentTool/runAgent.js'
|
||||
import type { AgentId } from '../../types/ids.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import {
|
||||
type CacheSafeParams,
|
||||
runForkedAgent,
|
||||
} from '../../utils/forkedAgent.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { createUserMessage } from '../../utils/messages.js'
|
||||
import { getAgentTranscript } from '../../utils/sessionStorage.js'
|
||||
|
||||
const SUMMARY_INTERVAL_MS = 30_000
|
||||
|
||||
function buildSummaryPrompt(previousSummary: string | null): string {
|
||||
const prevLine = previousSummary
|
||||
? `\nPrevious: "${previousSummary}" — say something NEW.\n`
|
||||
: ''
|
||||
|
||||
return `Describe your most recent action in 3-5 words using present tense (-ing). Name the file or function, not the branch. Do not use tools.
|
||||
${prevLine}
|
||||
Good: "Reading runAgent.ts"
|
||||
Good: "Fixing null check in validate.ts"
|
||||
Good: "Running auth module tests"
|
||||
Good: "Adding retry logic to fetchUser"
|
||||
|
||||
Bad (past tense): "Analyzed the branch diff"
|
||||
Bad (too vague): "Investigating the issue"
|
||||
Bad (too long): "Reviewing full branch diff and AgentTool.tsx integration"
|
||||
Bad (branch name): "Analyzed adam/background-summary branch diff"`
|
||||
}
|
||||
|
||||
export function startAgentSummarization(
|
||||
taskId: string,
|
||||
agentId: AgentId,
|
||||
cacheSafeParams: CacheSafeParams,
|
||||
setAppState: TaskContext['setAppState'],
|
||||
): { stop: () => void } {
|
||||
// Drop forkContextMessages from the closure — runSummary rebuilds it each
|
||||
// tick from getAgentTranscript(). Without this, the original fork messages
|
||||
// (passed from AgentTool.tsx) are pinned for the lifetime of the timer.
|
||||
const { forkContextMessages: _drop, ...baseParams } = cacheSafeParams
|
||||
let summaryAbortController: AbortController | null = null
|
||||
let timeoutId: ReturnType<typeof setTimeout> | null = null
|
||||
let stopped = false
|
||||
let previousSummary: string | null = null
|
||||
|
||||
async function runSummary(): Promise<void> {
|
||||
if (stopped) return
|
||||
|
||||
logForDebugging(`[AgentSummary] Timer fired for agent ${agentId}`)
|
||||
|
||||
try {
|
||||
// Read current messages from transcript
|
||||
const transcript = await getAgentTranscript(agentId)
|
||||
if (!transcript || transcript.messages.length < 3) {
|
||||
// Not enough context yet — finally block will schedule next attempt
|
||||
logForDebugging(
|
||||
`[AgentSummary] Skipping summary for ${taskId}: not enough messages (${transcript?.messages.length ?? 0})`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Filter to clean message state
|
||||
const cleanMessages = filterIncompleteToolCalls(transcript.messages)
|
||||
|
||||
// Build fork params with current messages
|
||||
const forkParams: CacheSafeParams = {
|
||||
...baseParams,
|
||||
forkContextMessages: cleanMessages,
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[AgentSummary] Forking for summary, ${cleanMessages.length} messages in context`,
|
||||
)
|
||||
|
||||
// Create abort controller for this summary
|
||||
summaryAbortController = new AbortController()
|
||||
|
||||
// Deny tools via callback, NOT by passing tools:[] - that busts cache
|
||||
const canUseTool = async () => ({
|
||||
behavior: 'deny' as const,
|
||||
message: 'No tools needed for summary',
|
||||
decisionReason: { type: 'other' as const, reason: 'summary only' },
|
||||
})
|
||||
|
||||
// DO NOT set maxOutputTokens here. The fork piggybacks on the main
|
||||
// thread's prompt cache by sending identical cache-key params (system,
|
||||
// tools, model, messages prefix, thinking config). Setting maxOutputTokens
|
||||
// would clamp budget_tokens, creating a thinking config mismatch that
|
||||
// invalidates the cache.
|
||||
//
|
||||
// ContentReplacementState is cloned by default in createSubagentContext
|
||||
// from forkParams.toolUseContext (the subagent's LIVE state captured at
|
||||
// onCacheSafeParams time). No explicit override needed.
|
||||
const result = await runForkedAgent({
|
||||
promptMessages: [
|
||||
createUserMessage({ content: buildSummaryPrompt(previousSummary) }),
|
||||
],
|
||||
cacheSafeParams: forkParams,
|
||||
canUseTool,
|
||||
querySource: 'agent_summary',
|
||||
forkLabel: 'agent_summary',
|
||||
overrides: { abortController: summaryAbortController },
|
||||
skipTranscript: true,
|
||||
})
|
||||
|
||||
if (stopped) return
|
||||
|
||||
// Extract summary text from result
|
||||
for (const msg of result.messages) {
|
||||
if (msg.type !== 'assistant') continue
|
||||
// Skip API error messages
|
||||
if (msg.isApiErrorMessage) {
|
||||
logForDebugging(
|
||||
`[AgentSummary] Skipping API error message for ${taskId}`,
|
||||
)
|
||||
continue
|
||||
}
|
||||
const textBlock = msg.message.content.find(b => b.type === 'text')
|
||||
if (textBlock?.type === 'text' && textBlock.text.trim()) {
|
||||
const summaryText = textBlock.text.trim()
|
||||
logForDebugging(
|
||||
`[AgentSummary] Summary result for ${taskId}: ${summaryText}`,
|
||||
)
|
||||
previousSummary = summaryText
|
||||
updateAgentSummary(taskId, summaryText, setAppState)
|
||||
break
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (!stopped && e instanceof Error) {
|
||||
logError(e)
|
||||
}
|
||||
} finally {
|
||||
summaryAbortController = null
|
||||
// Reset timer on completion (not initiation) to prevent overlapping summaries
|
||||
if (!stopped) {
|
||||
scheduleNext()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function scheduleNext(): void {
|
||||
if (stopped) return
|
||||
timeoutId = setTimeout(runSummary, SUMMARY_INTERVAL_MS)
|
||||
}
|
||||
|
||||
function stop(): void {
|
||||
logForDebugging(`[AgentSummary] Stopping summarization for ${taskId}`)
|
||||
stopped = true
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId)
|
||||
timeoutId = null
|
||||
}
|
||||
if (summaryAbortController) {
|
||||
summaryAbortController.abort()
|
||||
summaryAbortController = null
|
||||
}
|
||||
}
|
||||
|
||||
// Start the first timer
|
||||
scheduleNext()
|
||||
|
||||
return { stop }
|
||||
}
|
||||
254
src/services/MagicDocs/magicDocs.ts
Normal file
254
src/services/MagicDocs/magicDocs.ts
Normal file
@@ -0,0 +1,254 @@
|
||||
/**
|
||||
* Magic Docs automatically maintains markdown documentation files marked with special headers.
|
||||
* When a file with "# MAGIC DOC: [title]" is read, it runs periodically in the background
|
||||
* using a forked subagent to update the document with new learnings from the conversation.
|
||||
*
|
||||
* See docs/magic-docs.md for more information.
|
||||
*/
|
||||
|
||||
import type { Tool, ToolUseContext } from '../../Tool.js'
|
||||
import type { BuiltInAgentDefinition } from '../../tools/AgentTool/loadAgentsDir.js'
|
||||
import { runAgent } from '../../tools/AgentTool/runAgent.js'
|
||||
import { FILE_EDIT_TOOL_NAME } from '../../tools/FileEditTool/constants.js'
|
||||
import {
|
||||
FileReadTool,
|
||||
type Output as FileReadToolOutput,
|
||||
registerFileReadListener,
|
||||
} from '../../tools/FileReadTool/FileReadTool.js'
|
||||
import { isFsInaccessible } from '../../utils/errors.js'
|
||||
import { cloneFileStateCache } from '../../utils/fileStateCache.js'
|
||||
import {
|
||||
type REPLHookContext,
|
||||
registerPostSamplingHook,
|
||||
} from '../../utils/hooks/postSamplingHooks.js'
|
||||
import {
|
||||
createUserMessage,
|
||||
hasToolCallsInLastAssistantTurn,
|
||||
} from '../../utils/messages.js'
|
||||
import { sequential } from '../../utils/sequential.js'
|
||||
import { buildMagicDocsUpdatePrompt } from './prompts.js'
|
||||
|
||||
// Magic Doc header pattern: # MAGIC DOC: [title]
|
||||
// Matches at the start of the file (first line)
|
||||
const MAGIC_DOC_HEADER_PATTERN = /^#\s*MAGIC\s+DOC:\s*(.+)$/im
|
||||
// Pattern to match italics on the line immediately after the header
|
||||
const ITALICS_PATTERN = /^[_*](.+?)[_*]\s*$/m
|
||||
|
||||
// Track magic docs
|
||||
type MagicDocInfo = {
|
||||
path: string
|
||||
}
|
||||
|
||||
const trackedMagicDocs = new Map<string, MagicDocInfo>()
|
||||
|
||||
export function clearTrackedMagicDocs(): void {
|
||||
trackedMagicDocs.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect if a file content contains a Magic Doc header
|
||||
* Returns an object with title and optional instructions, or null if not a magic doc
|
||||
*/
|
||||
export function detectMagicDocHeader(
|
||||
content: string,
|
||||
): { title: string; instructions?: string } | null {
|
||||
const match = content.match(MAGIC_DOC_HEADER_PATTERN)
|
||||
if (!match || !match[1]) {
|
||||
return null
|
||||
}
|
||||
|
||||
const title = match[1].trim()
|
||||
|
||||
// Look for italics on the next line after the header (allow one optional blank line)
|
||||
const headerEndIndex = match.index! + match[0].length
|
||||
const afterHeader = content.slice(headerEndIndex)
|
||||
// Match: newline, optional blank line, then content line
|
||||
const nextLineMatch = afterHeader.match(/^\s*\n(?:\s*\n)?(.+?)(?:\n|$)/)
|
||||
|
||||
if (nextLineMatch && nextLineMatch[1]) {
|
||||
const nextLine = nextLineMatch[1]
|
||||
const italicsMatch = nextLine.match(ITALICS_PATTERN)
|
||||
if (italicsMatch && italicsMatch[1]) {
|
||||
const instructions = italicsMatch[1].trim()
|
||||
return {
|
||||
title,
|
||||
instructions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { title }
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a file as a Magic Doc when it's read
|
||||
* Only registers once per file path - the hook always reads latest content
|
||||
*/
|
||||
export function registerMagicDoc(filePath: string): void {
|
||||
// Only register if not already tracked
|
||||
if (!trackedMagicDocs.has(filePath)) {
|
||||
trackedMagicDocs.set(filePath, {
|
||||
path: filePath,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create Magic Docs agent definition
|
||||
*/
|
||||
function getMagicDocsAgent(): BuiltInAgentDefinition {
|
||||
return {
|
||||
agentType: 'magic-docs',
|
||||
whenToUse: 'Update Magic Docs',
|
||||
tools: [FILE_EDIT_TOOL_NAME], // Only allow Edit
|
||||
model: 'sonnet',
|
||||
source: 'built-in',
|
||||
baseDir: 'built-in',
|
||||
getSystemPrompt: () => '', // Will use override systemPrompt
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a single Magic Doc
|
||||
*/
|
||||
async function updateMagicDoc(
|
||||
docInfo: MagicDocInfo,
|
||||
context: REPLHookContext,
|
||||
): Promise<void> {
|
||||
const { messages, systemPrompt, userContext, systemContext, toolUseContext } =
|
||||
context
|
||||
|
||||
// Clone the FileStateCache to isolate Magic Docs operations. Delete this
|
||||
// doc's entry so FileReadTool's dedup doesn't return a file_unchanged
|
||||
// stub — we need the actual content to re-detect the header.
|
||||
const clonedReadFileState = cloneFileStateCache(toolUseContext.readFileState)
|
||||
clonedReadFileState.delete(docInfo.path)
|
||||
const clonedToolUseContext: ToolUseContext = {
|
||||
...toolUseContext,
|
||||
readFileState: clonedReadFileState,
|
||||
}
|
||||
|
||||
// Read the document; if deleted or unreadable, remove from tracking
|
||||
let currentDoc = ''
|
||||
try {
|
||||
const result = await FileReadTool.call(
|
||||
{ file_path: docInfo.path },
|
||||
clonedToolUseContext,
|
||||
)
|
||||
const output = result.data as FileReadToolOutput
|
||||
if (output.type === 'text') {
|
||||
currentDoc = output.file.content
|
||||
}
|
||||
} catch (e: unknown) {
|
||||
// FileReadTool wraps ENOENT in a plain Error("File does not exist...") with
|
||||
// no .code, so check the message in addition to isFsInaccessible (EACCES/EPERM).
|
||||
if (
|
||||
isFsInaccessible(e) ||
|
||||
(e instanceof Error && e.message.startsWith('File does not exist'))
|
||||
) {
|
||||
trackedMagicDocs.delete(docInfo.path)
|
||||
return
|
||||
}
|
||||
throw e
|
||||
}
|
||||
|
||||
// Re-detect title and instructions from latest file content
|
||||
const detected = detectMagicDocHeader(currentDoc)
|
||||
if (!detected) {
|
||||
// File no longer has magic doc header, remove from tracking
|
||||
trackedMagicDocs.delete(docInfo.path)
|
||||
return
|
||||
}
|
||||
|
||||
// Build update prompt with latest title and instructions
|
||||
const userPrompt = await buildMagicDocsUpdatePrompt(
|
||||
currentDoc,
|
||||
docInfo.path,
|
||||
detected.title,
|
||||
detected.instructions,
|
||||
)
|
||||
|
||||
// Create a custom canUseTool that only allows Edit for magic doc files
|
||||
const canUseTool = async (tool: Tool, input: unknown) => {
|
||||
if (
|
||||
tool.name === FILE_EDIT_TOOL_NAME &&
|
||||
typeof input === 'object' &&
|
||||
input !== null &&
|
||||
'file_path' in input
|
||||
) {
|
||||
const filePath = input.file_path
|
||||
if (typeof filePath === 'string' && filePath === docInfo.path) {
|
||||
return { behavior: 'allow' as const, updatedInput: input }
|
||||
}
|
||||
}
|
||||
return {
|
||||
behavior: 'deny' as const,
|
||||
message: `only ${FILE_EDIT_TOOL_NAME} is allowed for ${docInfo.path}`,
|
||||
decisionReason: {
|
||||
type: 'other' as const,
|
||||
reason: `only ${FILE_EDIT_TOOL_NAME} is allowed`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Run Magic Docs update using runAgent with forked context
|
||||
for await (const _message of runAgent({
|
||||
agentDefinition: getMagicDocsAgent(),
|
||||
promptMessages: [createUserMessage({ content: userPrompt })],
|
||||
toolUseContext: clonedToolUseContext,
|
||||
canUseTool,
|
||||
isAsync: true,
|
||||
forkContextMessages: messages,
|
||||
querySource: 'magic_docs',
|
||||
override: {
|
||||
systemPrompt,
|
||||
userContext,
|
||||
systemContext,
|
||||
},
|
||||
availableTools: clonedToolUseContext.options.tools,
|
||||
})) {
|
||||
// Just consume - let it run to completion
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Magic Docs post-sampling hook that updates all tracked Magic Docs
|
||||
*/
|
||||
const updateMagicDocs = sequential(async function (
|
||||
context: REPLHookContext,
|
||||
): Promise<void> {
|
||||
const { messages, querySource } = context
|
||||
|
||||
if (querySource !== 'repl_main_thread') {
|
||||
return
|
||||
}
|
||||
|
||||
// Only update when conversation is idle (no tool calls in last turn)
|
||||
const hasToolCalls = hasToolCallsInLastAssistantTurn(messages)
|
||||
if (hasToolCalls) {
|
||||
return
|
||||
}
|
||||
|
||||
const docCount = trackedMagicDocs.size
|
||||
if (docCount === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
for (const docInfo of Array.from(trackedMagicDocs.values())) {
|
||||
await updateMagicDoc(docInfo, context)
|
||||
}
|
||||
})
|
||||
|
||||
export async function initMagicDocs(): Promise<void> {
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
// Register listener to detect magic docs when files are read
|
||||
registerFileReadListener((filePath: string, content: string) => {
|
||||
const result = detectMagicDocHeader(content)
|
||||
if (result) {
|
||||
registerMagicDoc(filePath)
|
||||
}
|
||||
})
|
||||
|
||||
registerPostSamplingHook(updateMagicDocs)
|
||||
}
|
||||
}
|
||||
127
src/services/MagicDocs/prompts.ts
Normal file
127
src/services/MagicDocs/prompts.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { join } from 'path'
|
||||
import { getClaudeConfigHomeDir } from '../../utils/envUtils.js'
|
||||
import { getFsImplementation } from '../../utils/fsOperations.js'
|
||||
|
||||
/**
|
||||
* Get the Magic Docs update prompt template
|
||||
*/
|
||||
function getUpdatePromptTemplate(): string {
|
||||
return `IMPORTANT: This message and these instructions are NOT part of the actual user conversation. Do NOT include any references to "documentation updates", "magic docs", or these update instructions in the document content.
|
||||
|
||||
Based on the user conversation above (EXCLUDING this documentation update instruction message), update the Magic Doc file to incorporate any NEW learnings, insights, or information that would be valuable to preserve.
|
||||
|
||||
The file {{docPath}} has already been read for you. Here are its current contents:
|
||||
<current_doc_content>
|
||||
{{docContents}}
|
||||
</current_doc_content>
|
||||
|
||||
Document title: {{docTitle}}
|
||||
{{customInstructions}}
|
||||
|
||||
Your ONLY task is to use the Edit tool to update the documentation file if there is substantial new information to add, then stop. You can make multiple edits (update multiple sections as needed) - make all Edit tool calls in parallel in a single message. If there's nothing substantial to add, simply respond with a brief explanation and do not call any tools.
|
||||
|
||||
CRITICAL RULES FOR EDITING:
|
||||
- Preserve the Magic Doc header exactly as-is: # MAGIC DOC: {{docTitle}}
|
||||
- If there's an italicized line immediately after the header, preserve it exactly as-is
|
||||
- Keep the document CURRENT with the latest state of the codebase - this is NOT a changelog or history
|
||||
- Update information IN-PLACE to reflect the current state - do NOT append historical notes or track changes over time
|
||||
- Remove or replace outdated information rather than adding "Previously..." or "Updated to..." notes
|
||||
- Clean up or DELETE sections that are no longer relevant or don't align with the document's purpose
|
||||
- Fix obvious errors: typos, grammar mistakes, broken formatting, incorrect information, or confusing statements
|
||||
- Keep the document well organized: use clear headings, logical section order, consistent formatting, and proper nesting
|
||||
|
||||
DOCUMENTATION PHILOSOPHY - READ CAREFULLY:
|
||||
- BE TERSE. High signal only. No filler words or unnecessary elaboration.
|
||||
- Documentation is for OVERVIEWS, ARCHITECTURE, and ENTRY POINTS - not detailed code walkthroughs
|
||||
- Do NOT duplicate information that's already obvious from reading the source code
|
||||
- Do NOT document every function, parameter, or line number reference
|
||||
- Focus on: WHY things exist, HOW components connect, WHERE to start reading, WHAT patterns are used
|
||||
- Skip: detailed implementation steps, exhaustive API docs, play-by-play narratives
|
||||
|
||||
What TO document:
|
||||
- High-level architecture and system design
|
||||
- Non-obvious patterns, conventions, or gotchas
|
||||
- Key entry points and where to start reading code
|
||||
- Important design decisions and their rationale
|
||||
- Critical dependencies or integration points
|
||||
- References to related files, docs, or code (like a wiki) - help readers navigate to relevant context
|
||||
|
||||
What NOT to document:
|
||||
- Anything obvious from reading the code itself
|
||||
- Exhaustive lists of files, functions, or parameters
|
||||
- Step-by-step implementation details
|
||||
- Low-level code mechanics
|
||||
- Information already in CLAUDE.md or other project docs
|
||||
|
||||
Use the Edit tool with file_path: {{docPath}}
|
||||
|
||||
REMEMBER: Only update if there is substantial new information. The Magic Doc header (# MAGIC DOC: {{docTitle}}) must remain unchanged.`
|
||||
}
|
||||
|
||||
/**
|
||||
* Load custom Magic Docs prompt from file if it exists
|
||||
* Custom prompts can be placed at ~/.claude/magic-docs/prompt.md
|
||||
* Use {{variableName}} syntax for variable substitution (e.g., {{docContents}}, {{docPath}}, {{docTitle}})
|
||||
*/
|
||||
async function loadMagicDocsPrompt(): Promise<string> {
|
||||
const fs = getFsImplementation()
|
||||
const promptPath = join(getClaudeConfigHomeDir(), 'magic-docs', 'prompt.md')
|
||||
|
||||
try {
|
||||
return await fs.readFile(promptPath, { encoding: 'utf-8' })
|
||||
} catch {
|
||||
// Silently fall back to default if custom prompt doesn't exist or fails to load
|
||||
return getUpdatePromptTemplate()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Substitute variables in the prompt template using {{variable}} syntax
|
||||
*/
|
||||
function substituteVariables(
|
||||
template: string,
|
||||
variables: Record<string, string>,
|
||||
): string {
|
||||
// Single-pass replacement avoids two bugs: (1) $ backreference corruption
|
||||
// (replacer fn treats $ literally), and (2) double-substitution when user
|
||||
// content happens to contain {{varName}} matching a later variable.
|
||||
return template.replace(/\{\{(\w+)\}\}/g, (match, key: string) =>
|
||||
Object.prototype.hasOwnProperty.call(variables, key)
|
||||
? variables[key]!
|
||||
: match,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the Magic Docs update prompt with variable substitution
|
||||
*/
|
||||
export async function buildMagicDocsUpdatePrompt(
|
||||
docContents: string,
|
||||
docPath: string,
|
||||
docTitle: string,
|
||||
instructions?: string,
|
||||
): Promise<string> {
|
||||
const promptTemplate = await loadMagicDocsPrompt()
|
||||
|
||||
// Build custom instructions section if provided
|
||||
const customInstructions = instructions
|
||||
? `
|
||||
|
||||
DOCUMENT-SPECIFIC UPDATE INSTRUCTIONS:
|
||||
The document author has provided specific instructions for how this file should be updated. Pay extra attention to these instructions and follow them carefully:
|
||||
|
||||
"${instructions}"
|
||||
|
||||
These instructions take priority over the general rules below. Make sure your updates align with these specific guidelines.`
|
||||
: ''
|
||||
|
||||
// Substitute variables in the prompt
|
||||
const variables = {
|
||||
docContents,
|
||||
docPath,
|
||||
docTitle,
|
||||
customInstructions,
|
||||
}
|
||||
|
||||
return substituteVariables(promptTemplate, variables)
|
||||
}
|
||||
523
src/services/PromptSuggestion/promptSuggestion.ts
Normal file
523
src/services/PromptSuggestion/promptSuggestion.ts
Normal file
@@ -0,0 +1,523 @@
|
||||
import { getIsNonInteractiveSession } from '../../bootstrap/state.js'
|
||||
import type { AppState } from '../../state/AppState.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { isAgentSwarmsEnabled } from '../../utils/agentSwarmsEnabled.js'
|
||||
import { count } from '../../utils/array.js'
|
||||
import { isEnvDefinedFalsy, isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { toError } from '../../utils/errors.js'
|
||||
import {
|
||||
type CacheSafeParams,
|
||||
createCacheSafeParams,
|
||||
runForkedAgent,
|
||||
} from '../../utils/forkedAgent.js'
|
||||
import type { REPLHookContext } from '../../utils/hooks/postSamplingHooks.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import {
|
||||
createUserMessage,
|
||||
getLastAssistantMessage,
|
||||
} from '../../utils/messages.js'
|
||||
import { getInitialSettings } from '../../utils/settings/settings.js'
|
||||
import { isTeammate } from '../../utils/teammate.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
import { currentLimits } from '../claudeAiLimits.js'
|
||||
import { isSpeculationEnabled, startSpeculation } from './speculation.js'
|
||||
|
||||
let currentAbortController: AbortController | null = null
|
||||
|
||||
export type PromptVariant = 'user_intent' | 'stated_intent'
|
||||
|
||||
export function getPromptVariant(): PromptVariant {
|
||||
return 'user_intent'
|
||||
}
|
||||
|
||||
export function shouldEnablePromptSuggestion(): boolean {
|
||||
// Env var overrides everything (for testing)
|
||||
const envOverride = process.env.CLAUDE_CODE_ENABLE_PROMPT_SUGGESTION
|
||||
if (isEnvDefinedFalsy(envOverride)) {
|
||||
logEvent('tengu_prompt_suggestion_init', {
|
||||
enabled: false,
|
||||
source:
|
||||
'env' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return false
|
||||
}
|
||||
if (isEnvTruthy(envOverride)) {
|
||||
logEvent('tengu_prompt_suggestion_init', {
|
||||
enabled: true,
|
||||
source:
|
||||
'env' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// Keep default in sync with Config.tsx (settings toggle visibility)
|
||||
if (!getFeatureValue_CACHED_MAY_BE_STALE('tengu_chomp_inflection', false)) {
|
||||
logEvent('tengu_prompt_suggestion_init', {
|
||||
enabled: false,
|
||||
source:
|
||||
'growthbook' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Disable in non-interactive mode (print mode, piped input, SDK)
|
||||
if (getIsNonInteractiveSession()) {
|
||||
logEvent('tengu_prompt_suggestion_init', {
|
||||
enabled: false,
|
||||
source:
|
||||
'non_interactive' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Disable for swarm teammates (only leader should show suggestions)
|
||||
if (isAgentSwarmsEnabled() && isTeammate()) {
|
||||
logEvent('tengu_prompt_suggestion_init', {
|
||||
enabled: false,
|
||||
source:
|
||||
'swarm_teammate' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
const enabled = getInitialSettings()?.promptSuggestionEnabled !== false
|
||||
logEvent('tengu_prompt_suggestion_init', {
|
||||
enabled,
|
||||
source:
|
||||
'setting' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return enabled
|
||||
}
|
||||
|
||||
export function abortPromptSuggestion(): void {
|
||||
if (currentAbortController) {
|
||||
currentAbortController.abort()
|
||||
currentAbortController = null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a suppression reason if suggestions should not be generated,
|
||||
* or null if generation is allowed. Shared by main and pipelined paths.
|
||||
*/
|
||||
export function getSuggestionSuppressReason(appState: AppState): string | null {
|
||||
if (!appState.promptSuggestionEnabled) return 'disabled'
|
||||
if (appState.pendingWorkerRequest || appState.pendingSandboxRequest)
|
||||
return 'pending_permission'
|
||||
if (appState.elicitation.queue.length > 0) return 'elicitation_active'
|
||||
if (appState.toolPermissionContext.mode === 'plan') return 'plan_mode'
|
||||
if (
|
||||
process.env.USER_TYPE === 'external' &&
|
||||
currentLimits.status !== 'allowed'
|
||||
)
|
||||
return 'rate_limit'
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared guard + generation logic used by both CLI TUI and SDK push paths.
|
||||
* Returns the suggestion with metadata, or null if suppressed/filtered.
|
||||
*/
|
||||
export async function tryGenerateSuggestion(
|
||||
abortController: AbortController,
|
||||
messages: Message[],
|
||||
getAppState: () => AppState,
|
||||
cacheSafeParams: CacheSafeParams,
|
||||
source?: 'cli' | 'sdk',
|
||||
): Promise<{
|
||||
suggestion: string
|
||||
promptId: PromptVariant
|
||||
generationRequestId: string | null
|
||||
} | null> {
|
||||
if (abortController.signal.aborted) {
|
||||
logSuggestionSuppressed('aborted', undefined, undefined, source)
|
||||
return null
|
||||
}
|
||||
|
||||
const assistantTurnCount = count(messages, m => m.type === 'assistant')
|
||||
if (assistantTurnCount < 2) {
|
||||
logSuggestionSuppressed('early_conversation', undefined, undefined, source)
|
||||
return null
|
||||
}
|
||||
|
||||
const lastAssistantMessage = getLastAssistantMessage(messages)
|
||||
if (lastAssistantMessage?.isApiErrorMessage) {
|
||||
logSuggestionSuppressed('last_response_error', undefined, undefined, source)
|
||||
return null
|
||||
}
|
||||
const cacheReason = getParentCacheSuppressReason(lastAssistantMessage)
|
||||
if (cacheReason) {
|
||||
logSuggestionSuppressed(cacheReason, undefined, undefined, source)
|
||||
return null
|
||||
}
|
||||
|
||||
const appState = getAppState()
|
||||
const suppressReason = getSuggestionSuppressReason(appState)
|
||||
if (suppressReason) {
|
||||
logSuggestionSuppressed(suppressReason, undefined, undefined, source)
|
||||
return null
|
||||
}
|
||||
|
||||
const promptId = getPromptVariant()
|
||||
const { suggestion, generationRequestId } = await generateSuggestion(
|
||||
abortController,
|
||||
promptId,
|
||||
cacheSafeParams,
|
||||
)
|
||||
if (abortController.signal.aborted) {
|
||||
logSuggestionSuppressed('aborted', undefined, undefined, source)
|
||||
return null
|
||||
}
|
||||
if (!suggestion) {
|
||||
logSuggestionSuppressed('empty', undefined, promptId, source)
|
||||
return null
|
||||
}
|
||||
if (shouldFilterSuggestion(suggestion, promptId, source)) return null
|
||||
|
||||
return { suggestion, promptId, generationRequestId }
|
||||
}
|
||||
|
||||
export async function executePromptSuggestion(
|
||||
context: REPLHookContext,
|
||||
): Promise<void> {
|
||||
if (context.querySource !== 'repl_main_thread') return
|
||||
|
||||
currentAbortController = new AbortController()
|
||||
const abortController = currentAbortController
|
||||
const cacheSafeParams = createCacheSafeParams(context)
|
||||
|
||||
try {
|
||||
const result = await tryGenerateSuggestion(
|
||||
abortController,
|
||||
context.messages,
|
||||
context.toolUseContext.getAppState,
|
||||
cacheSafeParams,
|
||||
'cli',
|
||||
)
|
||||
if (!result) return
|
||||
|
||||
context.toolUseContext.setAppState(prev => ({
|
||||
...prev,
|
||||
promptSuggestion: {
|
||||
text: result.suggestion,
|
||||
promptId: result.promptId,
|
||||
shownAt: 0,
|
||||
acceptedAt: 0,
|
||||
generationRequestId: result.generationRequestId,
|
||||
},
|
||||
}))
|
||||
|
||||
if (isSpeculationEnabled() && result.suggestion) {
|
||||
void startSpeculation(
|
||||
result.suggestion,
|
||||
context,
|
||||
context.toolUseContext.setAppState,
|
||||
false,
|
||||
cacheSafeParams,
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
(error.name === 'AbortError' || error.name === 'APIUserAbortError')
|
||||
) {
|
||||
logSuggestionSuppressed('aborted', undefined, undefined, 'cli')
|
||||
return
|
||||
}
|
||||
logError(toError(error))
|
||||
} finally {
|
||||
if (currentAbortController === abortController) {
|
||||
currentAbortController = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_PARENT_UNCACHED_TOKENS = 10_000
|
||||
|
||||
export function getParentCacheSuppressReason(
|
||||
lastAssistantMessage: ReturnType<typeof getLastAssistantMessage>,
|
||||
): string | null {
|
||||
if (!lastAssistantMessage) return null
|
||||
|
||||
const usage = lastAssistantMessage.message.usage
|
||||
const inputTokens = usage.input_tokens ?? 0
|
||||
const cacheWriteTokens = usage.cache_creation_input_tokens ?? 0
|
||||
// The fork re-processes the parent's output (never cached) plus its own prompt.
|
||||
const outputTokens = usage.output_tokens ?? 0
|
||||
|
||||
return inputTokens + cacheWriteTokens + outputTokens >
|
||||
MAX_PARENT_UNCACHED_TOKENS
|
||||
? 'cache_cold'
|
||||
: null
|
||||
}
|
||||
|
||||
const SUGGESTION_PROMPT = `[SUGGESTION MODE: Suggest what the user might naturally type next into Claude Code.]
|
||||
|
||||
FIRST: Look at the user's recent messages and original request.
|
||||
|
||||
Your job is to predict what THEY would type - not what you think they should do.
|
||||
|
||||
THE TEST: Would they think "I was just about to type that"?
|
||||
|
||||
EXAMPLES:
|
||||
User asked "fix the bug and run tests", bug is fixed → "run the tests"
|
||||
After code written → "try it out"
|
||||
Claude offers options → suggest the one the user would likely pick, based on conversation
|
||||
Claude asks to continue → "yes" or "go ahead"
|
||||
Task complete, obvious follow-up → "commit this" or "push it"
|
||||
After error or misunderstanding → silence (let them assess/correct)
|
||||
|
||||
Be specific: "run the tests" beats "continue".
|
||||
|
||||
NEVER SUGGEST:
|
||||
- Evaluative ("looks good", "thanks")
|
||||
- Questions ("what about...?")
|
||||
- Claude-voice ("Let me...", "I'll...", "Here's...")
|
||||
- New ideas they didn't ask about
|
||||
- Multiple sentences
|
||||
|
||||
Stay silent if the next step isn't obvious from what the user said.
|
||||
|
||||
Format: 2-12 words, match the user's style. Or nothing.
|
||||
|
||||
Reply with ONLY the suggestion, no quotes or explanation.`
|
||||
|
||||
const SUGGESTION_PROMPTS: Record<PromptVariant, string> = {
|
||||
user_intent: SUGGESTION_PROMPT,
|
||||
stated_intent: SUGGESTION_PROMPT,
|
||||
}
|
||||
|
||||
export async function generateSuggestion(
|
||||
abortController: AbortController,
|
||||
promptId: PromptVariant,
|
||||
cacheSafeParams: CacheSafeParams,
|
||||
): Promise<{ suggestion: string | null; generationRequestId: string | null }> {
|
||||
const prompt = SUGGESTION_PROMPTS[promptId]
|
||||
|
||||
// Deny tools via callback, NOT by passing tools:[] - that busts cache (0% hit)
|
||||
const canUseTool = async () => ({
|
||||
behavior: 'deny' as const,
|
||||
message: 'No tools needed for suggestion',
|
||||
decisionReason: { type: 'other' as const, reason: 'suggestion only' },
|
||||
})
|
||||
|
||||
// DO NOT override any API parameter that differs from the parent request.
|
||||
// The fork piggybacks on the main thread's prompt cache by sending identical
|
||||
// cache-key params. The billing cache key includes more than just
|
||||
// system/tools/model/messages/thinking — empirically, setting effortValue
|
||||
// or maxOutputTokens on the fork (even via output_config or getAppState)
|
||||
// busts cache. PR #18143 tried effort:'low' and caused a 45x spike in cache
|
||||
// writes (92.7% → 61% hit rate). The only safe overrides are:
|
||||
// - abortController (not sent to API)
|
||||
// - skipTranscript (client-side only)
|
||||
// - skipCacheWrite (controls cache_control markers, not the cache key)
|
||||
// - canUseTool (client-side permission check)
|
||||
const result = await runForkedAgent({
|
||||
promptMessages: [createUserMessage({ content: prompt })],
|
||||
cacheSafeParams, // Don't override tools/thinking settings - busts cache
|
||||
canUseTool,
|
||||
querySource: 'prompt_suggestion',
|
||||
forkLabel: 'prompt_suggestion',
|
||||
overrides: {
|
||||
abortController,
|
||||
},
|
||||
skipTranscript: true,
|
||||
skipCacheWrite: true,
|
||||
})
|
||||
|
||||
// Check ALL messages - model may loop (try tool → denied → text in next message)
|
||||
// Also extract the requestId from the first assistant message for RL dataset joins
|
||||
const firstAssistantMsg = result.messages.find(m => m.type === 'assistant')
|
||||
const generationRequestId =
|
||||
firstAssistantMsg?.type === 'assistant'
|
||||
? (firstAssistantMsg.requestId ?? null)
|
||||
: null
|
||||
|
||||
for (const msg of result.messages) {
|
||||
if (msg.type !== 'assistant') continue
|
||||
const textBlock = msg.message.content.find(b => b.type === 'text')
|
||||
if (textBlock?.type === 'text') {
|
||||
const suggestion = textBlock.text.trim()
|
||||
if (suggestion) {
|
||||
return { suggestion, generationRequestId }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { suggestion: null, generationRequestId }
|
||||
}
|
||||
|
||||
export function shouldFilterSuggestion(
|
||||
suggestion: string | null,
|
||||
promptId: PromptVariant,
|
||||
source?: 'cli' | 'sdk',
|
||||
): boolean {
|
||||
if (!suggestion) {
|
||||
logSuggestionSuppressed('empty', undefined, promptId, source)
|
||||
return true
|
||||
}
|
||||
|
||||
const lower = suggestion.toLowerCase()
|
||||
const wordCount = suggestion.trim().split(/\s+/).length
|
||||
|
||||
const filters: Array<[string, () => boolean]> = [
|
||||
['done', () => lower === 'done'],
|
||||
[
|
||||
'meta_text',
|
||||
() =>
|
||||
lower === 'nothing found' ||
|
||||
lower === 'nothing found.' ||
|
||||
lower.startsWith('nothing to suggest') ||
|
||||
lower.startsWith('no suggestion') ||
|
||||
// Model spells out the prompt's "stay silent" instruction
|
||||
/\bsilence is\b|\bstay(s|ing)? silent\b/.test(lower) ||
|
||||
// Model outputs bare "silence" wrapped in punctuation/whitespace
|
||||
/^\W*silence\W*$/.test(lower),
|
||||
],
|
||||
[
|
||||
'meta_wrapped',
|
||||
// Model wraps meta-reasoning in parens/brackets: (silence — ...), [no suggestion]
|
||||
() => /^\(.*\)$|^\[.*\]$/.test(suggestion),
|
||||
],
|
||||
[
|
||||
'error_message',
|
||||
() =>
|
||||
lower.startsWith('api error:') ||
|
||||
lower.startsWith('prompt is too long') ||
|
||||
lower.startsWith('request timed out') ||
|
||||
lower.startsWith('invalid api key') ||
|
||||
lower.startsWith('image was too large'),
|
||||
],
|
||||
['prefixed_label', () => /^\w+:\s/.test(suggestion)],
|
||||
[
|
||||
'too_few_words',
|
||||
() => {
|
||||
if (wordCount >= 2) return false
|
||||
// Allow slash commands — these are valid user commands
|
||||
if (suggestion.startsWith('/')) return false
|
||||
// Allow common single-word inputs that are valid user commands
|
||||
const ALLOWED_SINGLE_WORDS = new Set([
|
||||
// Affirmatives
|
||||
'yes',
|
||||
'yeah',
|
||||
'yep',
|
||||
'yea',
|
||||
'yup',
|
||||
'sure',
|
||||
'ok',
|
||||
'okay',
|
||||
// Actions
|
||||
'push',
|
||||
'commit',
|
||||
'deploy',
|
||||
'stop',
|
||||
'continue',
|
||||
'check',
|
||||
'exit',
|
||||
'quit',
|
||||
// Negation
|
||||
'no',
|
||||
])
|
||||
return !ALLOWED_SINGLE_WORDS.has(lower)
|
||||
},
|
||||
],
|
||||
['too_many_words', () => wordCount > 12],
|
||||
['too_long', () => suggestion.length >= 100],
|
||||
['multiple_sentences', () => /[.!?]\s+[A-Z]/.test(suggestion)],
|
||||
['has_formatting', () => /[\n*]|\*\*/.test(suggestion)],
|
||||
[
|
||||
'evaluative',
|
||||
() =>
|
||||
/thanks|thank you|looks good|sounds good|that works|that worked|that's all|nice|great|perfect|makes sense|awesome|excellent/.test(
|
||||
lower,
|
||||
),
|
||||
],
|
||||
[
|
||||
'claude_voice',
|
||||
() =>
|
||||
/^(let me|i'll|i've|i'm|i can|i would|i think|i notice|here's|here is|here are|that's|this is|this will|you can|you should|you could|sure,|of course|certainly)/i.test(
|
||||
suggestion,
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
for (const [reason, check] of filters) {
|
||||
if (check()) {
|
||||
logSuggestionSuppressed(reason, suggestion, promptId, source)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Log acceptance/ignoring of a prompt suggestion. Used by the SDK push path
|
||||
* to track outcomes when the next user message arrives.
|
||||
*/
|
||||
export function logSuggestionOutcome(
|
||||
suggestion: string,
|
||||
userInput: string,
|
||||
emittedAt: number,
|
||||
promptId: PromptVariant,
|
||||
generationRequestId: string | null,
|
||||
): void {
|
||||
const similarity =
|
||||
Math.round((userInput.length / (suggestion.length || 1)) * 100) / 100
|
||||
const wasAccepted = userInput === suggestion
|
||||
const timeMs = Math.max(0, Date.now() - emittedAt)
|
||||
|
||||
logEvent('tengu_prompt_suggestion', {
|
||||
source: 'sdk' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
outcome: (wasAccepted
|
||||
? 'accepted'
|
||||
: 'ignored') as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
prompt_id:
|
||||
promptId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...(generationRequestId && {
|
||||
generationRequestId:
|
||||
generationRequestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}),
|
||||
...(wasAccepted && {
|
||||
timeToAcceptMs: timeMs,
|
||||
}),
|
||||
...(!wasAccepted && { timeToIgnoreMs: timeMs }),
|
||||
similarity,
|
||||
...(process.env.USER_TYPE === 'ant' && {
|
||||
suggestion:
|
||||
suggestion as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
userInput:
|
||||
userInput as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
export function logSuggestionSuppressed(
|
||||
reason: string,
|
||||
suggestion?: string,
|
||||
promptId?: PromptVariant,
|
||||
source?: 'cli' | 'sdk',
|
||||
): void {
|
||||
const resolvedPromptId = promptId ?? getPromptVariant()
|
||||
logEvent('tengu_prompt_suggestion', {
|
||||
...(source && {
|
||||
source:
|
||||
source as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}),
|
||||
outcome:
|
||||
'suppressed' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
reason:
|
||||
reason as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
prompt_id:
|
||||
resolvedPromptId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...(process.env.USER_TYPE === 'ant' &&
|
||||
suggestion && {
|
||||
suggestion:
|
||||
suggestion as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}),
|
||||
})
|
||||
}
|
||||
991
src/services/PromptSuggestion/speculation.ts
Normal file
991
src/services/PromptSuggestion/speculation.ts
Normal file
@@ -0,0 +1,991 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { rm } from 'fs'
|
||||
import { appendFile, copyFile, mkdir } from 'fs/promises'
|
||||
import { dirname, isAbsolute, join, relative } from 'path'
|
||||
import { getCwdState } from '../../bootstrap/state.js'
|
||||
import type { CompletionBoundary } from '../../state/AppStateStore.js'
|
||||
import {
|
||||
type AppState,
|
||||
IDLE_SPECULATION_STATE,
|
||||
type SpeculationResult,
|
||||
type SpeculationState,
|
||||
} from '../../state/AppStateStore.js'
|
||||
import { commandHasAnyCd } from '../../tools/BashTool/bashPermissions.js'
|
||||
import { checkReadOnlyConstraints } from '../../tools/BashTool/readOnlyValidation.js'
|
||||
import type { SpeculationAcceptMessage } from '../../types/logs.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { createChildAbortController } from '../../utils/abortController.js'
|
||||
import { count } from '../../utils/array.js'
|
||||
import { getGlobalConfig } from '../../utils/config.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import {
|
||||
type FileStateCache,
|
||||
mergeFileStateCaches,
|
||||
READ_FILE_STATE_CACHE_SIZE,
|
||||
} from '../../utils/fileStateCache.js'
|
||||
import {
|
||||
type CacheSafeParams,
|
||||
createCacheSafeParams,
|
||||
runForkedAgent,
|
||||
} from '../../utils/forkedAgent.js'
|
||||
import { formatDuration, formatNumber } from '../../utils/format.js'
|
||||
import type { REPLHookContext } from '../../utils/hooks/postSamplingHooks.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import type { SetAppState } from '../../utils/messageQueueManager.js'
|
||||
import {
|
||||
createSystemMessage,
|
||||
createUserMessage,
|
||||
INTERRUPT_MESSAGE,
|
||||
INTERRUPT_MESSAGE_FOR_TOOL_USE,
|
||||
} from '../../utils/messages.js'
|
||||
import { getClaudeTempDir } from '../../utils/permissions/filesystem.js'
|
||||
import { extractReadFilesFromMessages } from '../../utils/queryHelpers.js'
|
||||
import { getTranscriptPath } from '../../utils/sessionStorage.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
import {
|
||||
generateSuggestion,
|
||||
getPromptVariant,
|
||||
getSuggestionSuppressReason,
|
||||
logSuggestionSuppressed,
|
||||
shouldFilterSuggestion,
|
||||
} from './promptSuggestion.js'
|
||||
|
||||
const MAX_SPECULATION_TURNS = 20
|
||||
const MAX_SPECULATION_MESSAGES = 100
|
||||
|
||||
const WRITE_TOOLS = new Set(['Edit', 'Write', 'NotebookEdit'])
|
||||
const SAFE_READ_ONLY_TOOLS = new Set([
|
||||
'Read',
|
||||
'Glob',
|
||||
'Grep',
|
||||
'ToolSearch',
|
||||
'LSP',
|
||||
'TaskGet',
|
||||
'TaskList',
|
||||
])
|
||||
|
||||
function safeRemoveOverlay(overlayPath: string): void {
|
||||
rm(
|
||||
overlayPath,
|
||||
{ recursive: true, force: true, maxRetries: 3, retryDelay: 100 },
|
||||
() => {},
|
||||
)
|
||||
}
|
||||
|
||||
function getOverlayPath(id: string): string {
|
||||
return join(getClaudeTempDir(), 'speculation', String(process.pid), id)
|
||||
}
|
||||
|
||||
function denySpeculation(
|
||||
message: string,
|
||||
reason: string,
|
||||
): {
|
||||
behavior: 'deny'
|
||||
message: string
|
||||
decisionReason: { type: 'other'; reason: string }
|
||||
} {
|
||||
return {
|
||||
behavior: 'deny',
|
||||
message,
|
||||
decisionReason: { type: 'other', reason },
|
||||
}
|
||||
}
|
||||
|
||||
async function copyOverlayToMain(
|
||||
overlayPath: string,
|
||||
writtenPaths: Set<string>,
|
||||
cwd: string,
|
||||
): Promise<boolean> {
|
||||
let allCopied = true
|
||||
for (const rel of writtenPaths) {
|
||||
const src = join(overlayPath, rel)
|
||||
const dest = join(cwd, rel)
|
||||
try {
|
||||
await mkdir(dirname(dest), { recursive: true })
|
||||
await copyFile(src, dest)
|
||||
} catch {
|
||||
allCopied = false
|
||||
logForDebugging(`[Speculation] Failed to copy ${rel} to main`)
|
||||
}
|
||||
}
|
||||
return allCopied
|
||||
}
|
||||
|
||||
export type ActiveSpeculationState = Extract<
|
||||
SpeculationState,
|
||||
{ status: 'active' }
|
||||
>
|
||||
|
||||
function logSpeculation(
|
||||
id: string,
|
||||
outcome: 'accepted' | 'aborted' | 'error',
|
||||
startTime: number,
|
||||
suggestionLength: number,
|
||||
messages: Message[],
|
||||
boundary: CompletionBoundary | null,
|
||||
extras?: Record<string, string | number | boolean | undefined>,
|
||||
): void {
|
||||
logEvent('tengu_speculation', {
|
||||
speculation_id:
|
||||
id as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
outcome:
|
||||
outcome as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
duration_ms: Date.now() - startTime,
|
||||
suggestion_length: suggestionLength,
|
||||
tools_executed: countToolsInMessages(messages),
|
||||
completed: boundary !== null,
|
||||
boundary_type: boundary?.type as
|
||||
| AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
| undefined,
|
||||
boundary_tool: getBoundaryTool(boundary) as
|
||||
| AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
| undefined,
|
||||
boundary_detail: getBoundaryDetail(boundary) as
|
||||
| AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
| undefined,
|
||||
...extras,
|
||||
})
|
||||
}
|
||||
|
||||
function countToolsInMessages(messages: Message[]): number {
|
||||
const blocks = messages
|
||||
.filter(isUserMessageWithArrayContent)
|
||||
.flatMap(m => m.message.content)
|
||||
.filter(
|
||||
(b): b is { type: string; is_error?: boolean } =>
|
||||
typeof b === 'object' && b !== null && 'type' in b,
|
||||
)
|
||||
return count(blocks, b => b.type === 'tool_result' && !b.is_error)
|
||||
}
|
||||
|
||||
function getBoundaryTool(
|
||||
boundary: CompletionBoundary | null,
|
||||
): string | undefined {
|
||||
if (!boundary) return undefined
|
||||
switch (boundary.type) {
|
||||
case 'bash':
|
||||
return 'Bash'
|
||||
case 'edit':
|
||||
case 'denied_tool':
|
||||
return boundary.toolName
|
||||
case 'complete':
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
function getBoundaryDetail(
|
||||
boundary: CompletionBoundary | null,
|
||||
): string | undefined {
|
||||
if (!boundary) return undefined
|
||||
switch (boundary.type) {
|
||||
case 'bash':
|
||||
return boundary.command.slice(0, 200)
|
||||
case 'edit':
|
||||
return boundary.filePath
|
||||
case 'denied_tool':
|
||||
return boundary.detail
|
||||
case 'complete':
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
function isUserMessageWithArrayContent(
|
||||
m: Message,
|
||||
): m is Message & { message: { content: unknown[] } } {
|
||||
return m.type === 'user' && 'message' in m && Array.isArray(m.message.content)
|
||||
}
|
||||
|
||||
export function prepareMessagesForInjection(messages: Message[]): Message[] {
|
||||
// Find tool_use IDs that have SUCCESSFUL results (not errors/interruptions)
|
||||
// Pending tool_use blocks (no result) and interrupted ones will be stripped
|
||||
type ToolResult = {
|
||||
type: 'tool_result'
|
||||
tool_use_id: string
|
||||
is_error?: boolean
|
||||
content?: unknown
|
||||
}
|
||||
const isToolResult = (b: unknown): b is ToolResult =>
|
||||
typeof b === 'object' &&
|
||||
b !== null &&
|
||||
(b as ToolResult).type === 'tool_result' &&
|
||||
typeof (b as ToolResult).tool_use_id === 'string'
|
||||
const isSuccessful = (b: ToolResult) =>
|
||||
!b.is_error &&
|
||||
!(
|
||||
typeof b.content === 'string' &&
|
||||
b.content.includes(INTERRUPT_MESSAGE_FOR_TOOL_USE)
|
||||
)
|
||||
|
||||
const toolIdsWithSuccessfulResults = new Set(
|
||||
messages
|
||||
.filter(isUserMessageWithArrayContent)
|
||||
.flatMap(m => m.message.content)
|
||||
.filter(isToolResult)
|
||||
.filter(isSuccessful)
|
||||
.map(b => b.tool_use_id),
|
||||
)
|
||||
|
||||
const keep = (b: {
|
||||
type: string
|
||||
id?: string
|
||||
tool_use_id?: string
|
||||
text?: string
|
||||
}) =>
|
||||
b.type !== 'thinking' &&
|
||||
b.type !== 'redacted_thinking' &&
|
||||
!(b.type === 'tool_use' && !toolIdsWithSuccessfulResults.has(b.id!)) &&
|
||||
!(
|
||||
b.type === 'tool_result' &&
|
||||
!toolIdsWithSuccessfulResults.has(b.tool_use_id!)
|
||||
) &&
|
||||
// Abort during speculation yields a standalone interrupt user message
|
||||
// (query.ts createUserInterruptionMessage). Strip it so it isn't surfaced
|
||||
// to the model as real user input.
|
||||
!(
|
||||
b.type === 'text' &&
|
||||
(b.text === INTERRUPT_MESSAGE ||
|
||||
b.text === INTERRUPT_MESSAGE_FOR_TOOL_USE)
|
||||
)
|
||||
|
||||
return messages
|
||||
.map(msg => {
|
||||
if (!('message' in msg) || !Array.isArray(msg.message.content)) return msg
|
||||
const content = msg.message.content.filter(keep)
|
||||
if (content.length === msg.message.content.length) return msg
|
||||
if (content.length === 0) return null
|
||||
// Drop messages where all remaining blocks are whitespace-only text
|
||||
// (API rejects these with 400: "text content blocks must contain non-whitespace text")
|
||||
const hasNonWhitespaceContent = content.some(
|
||||
(b: { type: string; text?: string }) =>
|
||||
b.type !== 'text' || (b.text !== undefined && b.text.trim() !== ''),
|
||||
)
|
||||
if (!hasNonWhitespaceContent) return null
|
||||
return { ...msg, message: { ...msg.message, content } } as typeof msg
|
||||
})
|
||||
.filter((m): m is Message => m !== null)
|
||||
}
|
||||
|
||||
function createSpeculationFeedbackMessage(
|
||||
messages: Message[],
|
||||
boundary: CompletionBoundary | null,
|
||||
timeSavedMs: number,
|
||||
sessionTotalMs: number,
|
||||
): Message | null {
|
||||
if (process.env.USER_TYPE !== 'ant') return null
|
||||
|
||||
if (messages.length === 0 || timeSavedMs === 0) return null
|
||||
|
||||
const toolUses = countToolsInMessages(messages)
|
||||
const tokens = boundary?.type === 'complete' ? boundary.outputTokens : null
|
||||
|
||||
const parts = []
|
||||
if (toolUses > 0) {
|
||||
parts.push(`Speculated ${toolUses} tool ${toolUses === 1 ? 'use' : 'uses'}`)
|
||||
} else {
|
||||
const turns = messages.length
|
||||
parts.push(`Speculated ${turns} ${turns === 1 ? 'turn' : 'turns'}`)
|
||||
}
|
||||
|
||||
if (tokens !== null) {
|
||||
parts.push(`${formatNumber(tokens)} tokens`)
|
||||
}
|
||||
|
||||
const savedText = `+${formatDuration(timeSavedMs)} saved`
|
||||
const sessionSuffix =
|
||||
sessionTotalMs !== timeSavedMs
|
||||
? ` (${formatDuration(sessionTotalMs)} this session)`
|
||||
: ''
|
||||
|
||||
return createSystemMessage(
|
||||
`[ANT-ONLY] ${parts.join(' · ')} · ${savedText}${sessionSuffix}`,
|
||||
'warning',
|
||||
)
|
||||
}
|
||||
|
||||
function updateActiveSpeculationState(
|
||||
setAppState: SetAppState,
|
||||
updater: (state: ActiveSpeculationState) => Partial<ActiveSpeculationState>,
|
||||
): void {
|
||||
setAppState(prev => {
|
||||
if (prev.speculation.status !== 'active') return prev
|
||||
const current = prev.speculation as ActiveSpeculationState
|
||||
const updates = updater(current)
|
||||
// Check if any values actually changed to avoid unnecessary re-renders
|
||||
const hasChanges = Object.entries(updates).some(
|
||||
([key, value]) => current[key as keyof ActiveSpeculationState] !== value,
|
||||
)
|
||||
if (!hasChanges) return prev
|
||||
return {
|
||||
...prev,
|
||||
speculation: { ...current, ...updates },
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function resetSpeculationState(setAppState: SetAppState): void {
|
||||
setAppState(prev => {
|
||||
if (prev.speculation.status === 'idle') return prev
|
||||
return { ...prev, speculation: IDLE_SPECULATION_STATE }
|
||||
})
|
||||
}
|
||||
|
||||
export function isSpeculationEnabled(): boolean {
|
||||
const enabled =
|
||||
process.env.USER_TYPE === 'ant' &&
|
||||
(getGlobalConfig().speculationEnabled ?? true)
|
||||
logForDebugging(`[Speculation] enabled=${enabled}`)
|
||||
return enabled
|
||||
}
|
||||
|
||||
async function generatePipelinedSuggestion(
|
||||
context: REPLHookContext,
|
||||
suggestionText: string,
|
||||
speculatedMessages: Message[],
|
||||
setAppState: SetAppState,
|
||||
parentAbortController: AbortController,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const appState = context.toolUseContext.getAppState()
|
||||
const suppressReason = getSuggestionSuppressReason(appState)
|
||||
if (suppressReason) {
|
||||
logSuggestionSuppressed(`pipeline_${suppressReason}`)
|
||||
return
|
||||
}
|
||||
|
||||
const augmentedContext: REPLHookContext = {
|
||||
...context,
|
||||
messages: [
|
||||
...context.messages,
|
||||
createUserMessage({ content: suggestionText }),
|
||||
...speculatedMessages,
|
||||
],
|
||||
}
|
||||
|
||||
const pipelineAbortController = createChildAbortController(
|
||||
parentAbortController,
|
||||
)
|
||||
if (pipelineAbortController.signal.aborted) return
|
||||
|
||||
const promptId = getPromptVariant()
|
||||
const { suggestion, generationRequestId } = await generateSuggestion(
|
||||
pipelineAbortController,
|
||||
promptId,
|
||||
createCacheSafeParams(augmentedContext),
|
||||
)
|
||||
|
||||
if (pipelineAbortController.signal.aborted) return
|
||||
if (shouldFilterSuggestion(suggestion, promptId)) return
|
||||
|
||||
logForDebugging(
|
||||
`[Speculation] Pipelined suggestion: "${suggestion!.slice(0, 50)}..."`,
|
||||
)
|
||||
updateActiveSpeculationState(setAppState, () => ({
|
||||
pipelinedSuggestion: {
|
||||
text: suggestion!,
|
||||
promptId,
|
||||
generationRequestId,
|
||||
},
|
||||
}))
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'AbortError') return
|
||||
logForDebugging(
|
||||
`[Speculation] Pipelined suggestion failed: ${errorMessage(error)}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
export async function startSpeculation(
|
||||
suggestionText: string,
|
||||
context: REPLHookContext,
|
||||
setAppState: (f: (prev: AppState) => AppState) => void,
|
||||
isPipelined = false,
|
||||
cacheSafeParams?: CacheSafeParams,
|
||||
): Promise<void> {
|
||||
if (!isSpeculationEnabled()) return
|
||||
|
||||
// Abort any existing speculation before starting a new one
|
||||
abortSpeculation(setAppState)
|
||||
|
||||
const id = randomUUID().slice(0, 8)
|
||||
|
||||
const abortController = createChildAbortController(
|
||||
context.toolUseContext.abortController,
|
||||
)
|
||||
|
||||
if (abortController.signal.aborted) return
|
||||
|
||||
const startTime = Date.now()
|
||||
const messagesRef = { current: [] as Message[] }
|
||||
const writtenPathsRef = { current: new Set<string>() }
|
||||
const overlayPath = getOverlayPath(id)
|
||||
const cwd = getCwdState()
|
||||
|
||||
try {
|
||||
await mkdir(overlayPath, { recursive: true })
|
||||
} catch {
|
||||
logForDebugging('[Speculation] Failed to create overlay directory')
|
||||
return
|
||||
}
|
||||
|
||||
const contextRef = { current: context }
|
||||
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
speculation: {
|
||||
status: 'active',
|
||||
id,
|
||||
abort: () => abortController.abort(),
|
||||
startTime,
|
||||
messagesRef,
|
||||
writtenPathsRef,
|
||||
boundary: null,
|
||||
suggestionLength: suggestionText.length,
|
||||
toolUseCount: 0,
|
||||
isPipelined,
|
||||
contextRef,
|
||||
},
|
||||
}))
|
||||
|
||||
logForDebugging(`[Speculation] Starting speculation ${id}`)
|
||||
|
||||
try {
|
||||
const result = await runForkedAgent({
|
||||
promptMessages: [createUserMessage({ content: suggestionText })],
|
||||
cacheSafeParams: cacheSafeParams ?? createCacheSafeParams(context),
|
||||
skipTranscript: true,
|
||||
canUseTool: async (tool, input) => {
|
||||
const isWriteTool = WRITE_TOOLS.has(tool.name)
|
||||
const isSafeReadOnlyTool = SAFE_READ_ONLY_TOOLS.has(tool.name)
|
||||
|
||||
// Check permission mode BEFORE allowing file edits
|
||||
if (isWriteTool) {
|
||||
const appState = context.toolUseContext.getAppState()
|
||||
const { mode, isBypassPermissionsModeAvailable } =
|
||||
appState.toolPermissionContext
|
||||
|
||||
const canAutoAcceptEdits =
|
||||
mode === 'acceptEdits' ||
|
||||
mode === 'bypassPermissions' ||
|
||||
(mode === 'plan' && isBypassPermissionsModeAvailable)
|
||||
|
||||
if (!canAutoAcceptEdits) {
|
||||
logForDebugging(`[Speculation] Stopping at file edit: ${tool.name}`)
|
||||
const editPath = (
|
||||
'file_path' in input ? input.file_path : undefined
|
||||
) as string | undefined
|
||||
updateActiveSpeculationState(setAppState, () => ({
|
||||
boundary: {
|
||||
type: 'edit',
|
||||
toolName: tool.name,
|
||||
filePath: editPath ?? '',
|
||||
completedAt: Date.now(),
|
||||
},
|
||||
}))
|
||||
abortController.abort()
|
||||
return denySpeculation(
|
||||
'Speculation paused: file edit requires permission',
|
||||
'speculation_edit_boundary',
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle file path rewriting for overlay isolation
|
||||
if (isWriteTool || isSafeReadOnlyTool) {
|
||||
const pathKey =
|
||||
'notebook_path' in input
|
||||
? 'notebook_path'
|
||||
: 'path' in input
|
||||
? 'path'
|
||||
: 'file_path'
|
||||
const filePath = input[pathKey] as string | undefined
|
||||
if (filePath) {
|
||||
const rel = relative(cwd, filePath)
|
||||
if (isAbsolute(rel) || rel.startsWith('..')) {
|
||||
if (isWriteTool) {
|
||||
logForDebugging(
|
||||
`[Speculation] Denied ${tool.name}: path outside cwd: ${filePath}`,
|
||||
)
|
||||
return denySpeculation(
|
||||
'Write outside cwd not allowed during speculation',
|
||||
'speculation_write_outside_root',
|
||||
)
|
||||
}
|
||||
return {
|
||||
behavior: 'allow' as const,
|
||||
updatedInput: input,
|
||||
decisionReason: {
|
||||
type: 'other' as const,
|
||||
reason: 'speculation_read_outside_root',
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if (isWriteTool) {
|
||||
// Copy-on-write: copy original to overlay if not yet there
|
||||
if (!writtenPathsRef.current.has(rel)) {
|
||||
const overlayFile = join(overlayPath, rel)
|
||||
await mkdir(dirname(overlayFile), { recursive: true })
|
||||
try {
|
||||
await copyFile(join(cwd, rel), overlayFile)
|
||||
} catch {
|
||||
// Original may not exist (new file creation) - that's fine
|
||||
}
|
||||
writtenPathsRef.current.add(rel)
|
||||
}
|
||||
input = { ...input, [pathKey]: join(overlayPath, rel) }
|
||||
} else {
|
||||
// Read: redirect to overlay if file was previously written
|
||||
if (writtenPathsRef.current.has(rel)) {
|
||||
input = { ...input, [pathKey]: join(overlayPath, rel) }
|
||||
}
|
||||
// Otherwise read from main (no rewrite)
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[Speculation] ${isWriteTool ? 'Write' : 'Read'} ${filePath} -> ${input[pathKey]}`,
|
||||
)
|
||||
|
||||
return {
|
||||
behavior: 'allow' as const,
|
||||
updatedInput: input,
|
||||
decisionReason: {
|
||||
type: 'other' as const,
|
||||
reason: 'speculation_file_access',
|
||||
},
|
||||
}
|
||||
}
|
||||
// Read tools without explicit path (e.g. Glob/Grep defaulting to CWD) are safe
|
||||
if (isSafeReadOnlyTool) {
|
||||
return {
|
||||
behavior: 'allow' as const,
|
||||
updatedInput: input,
|
||||
decisionReason: {
|
||||
type: 'other' as const,
|
||||
reason: 'speculation_read_default_cwd',
|
||||
},
|
||||
}
|
||||
}
|
||||
// Write tools with undefined path → fall through to default deny
|
||||
}
|
||||
|
||||
// Stop at non-read-only bash commands
|
||||
if (tool.name === 'Bash') {
|
||||
const command =
|
||||
'command' in input && typeof input.command === 'string'
|
||||
? input.command
|
||||
: ''
|
||||
if (
|
||||
!command ||
|
||||
checkReadOnlyConstraints({ command }, commandHasAnyCd(command))
|
||||
.behavior !== 'allow'
|
||||
) {
|
||||
logForDebugging(
|
||||
`[Speculation] Stopping at bash: ${command.slice(0, 50) || 'missing command'}`,
|
||||
)
|
||||
updateActiveSpeculationState(setAppState, () => ({
|
||||
boundary: { type: 'bash', command, completedAt: Date.now() },
|
||||
}))
|
||||
abortController.abort()
|
||||
return denySpeculation(
|
||||
'Speculation paused: bash boundary',
|
||||
'speculation_bash_boundary',
|
||||
)
|
||||
}
|
||||
// Read-only bash command — allow during speculation
|
||||
return {
|
||||
behavior: 'allow' as const,
|
||||
updatedInput: input,
|
||||
decisionReason: {
|
||||
type: 'other' as const,
|
||||
reason: 'speculation_readonly_bash',
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Deny all other tools by default
|
||||
logForDebugging(`[Speculation] Stopping at denied tool: ${tool.name}`)
|
||||
const detail = String(
|
||||
('url' in input && input.url) ||
|
||||
('file_path' in input && input.file_path) ||
|
||||
('path' in input && input.path) ||
|
||||
('command' in input && input.command) ||
|
||||
'',
|
||||
).slice(0, 200)
|
||||
updateActiveSpeculationState(setAppState, () => ({
|
||||
boundary: {
|
||||
type: 'denied_tool',
|
||||
toolName: tool.name,
|
||||
detail,
|
||||
completedAt: Date.now(),
|
||||
},
|
||||
}))
|
||||
abortController.abort()
|
||||
return denySpeculation(
|
||||
`Tool ${tool.name} not allowed during speculation`,
|
||||
'speculation_unknown_tool',
|
||||
)
|
||||
},
|
||||
querySource: 'speculation',
|
||||
forkLabel: 'speculation',
|
||||
maxTurns: MAX_SPECULATION_TURNS,
|
||||
overrides: { abortController, requireCanUseTool: true },
|
||||
onMessage: msg => {
|
||||
if (msg.type === 'assistant' || msg.type === 'user') {
|
||||
messagesRef.current.push(msg)
|
||||
if (messagesRef.current.length >= MAX_SPECULATION_MESSAGES) {
|
||||
abortController.abort()
|
||||
}
|
||||
if (isUserMessageWithArrayContent(msg)) {
|
||||
const newTools = count(
|
||||
msg.message.content as { type: string; is_error?: boolean }[],
|
||||
b => b.type === 'tool_result' && !b.is_error,
|
||||
)
|
||||
if (newTools > 0) {
|
||||
updateActiveSpeculationState(setAppState, prev => ({
|
||||
toolUseCount: prev.toolUseCount + newTools,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
if (abortController.signal.aborted) return
|
||||
|
||||
updateActiveSpeculationState(setAppState, () => ({
|
||||
boundary: {
|
||||
type: 'complete' as const,
|
||||
completedAt: Date.now(),
|
||||
outputTokens: result.totalUsage.output_tokens,
|
||||
},
|
||||
}))
|
||||
|
||||
logForDebugging(
|
||||
`[Speculation] Complete: ${countToolsInMessages(messagesRef.current)} tools`,
|
||||
)
|
||||
|
||||
// Pipeline: generate the next suggestion while we wait for the user to accept
|
||||
void generatePipelinedSuggestion(
|
||||
contextRef.current,
|
||||
suggestionText,
|
||||
messagesRef.current,
|
||||
setAppState,
|
||||
abortController,
|
||||
)
|
||||
} catch (error) {
|
||||
abortController.abort()
|
||||
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
safeRemoveOverlay(overlayPath)
|
||||
resetSpeculationState(setAppState)
|
||||
return
|
||||
}
|
||||
|
||||
safeRemoveOverlay(overlayPath)
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- custom fallback message, not toError(e)
|
||||
logError(error instanceof Error ? error : new Error('Speculation failed'))
|
||||
|
||||
logSpeculation(
|
||||
id,
|
||||
'error',
|
||||
startTime,
|
||||
suggestionText.length,
|
||||
messagesRef.current,
|
||||
null,
|
||||
{
|
||||
error_type: error instanceof Error ? error.name : 'Unknown',
|
||||
error_message: errorMessage(error).slice(
|
||||
0,
|
||||
200,
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
error_phase:
|
||||
'start' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
is_pipelined: isPipelined,
|
||||
},
|
||||
)
|
||||
|
||||
resetSpeculationState(setAppState)
|
||||
}
|
||||
}
|
||||
|
||||
export async function acceptSpeculation(
|
||||
state: SpeculationState,
|
||||
setAppState: (f: (prev: AppState) => AppState) => void,
|
||||
cleanMessageCount: number,
|
||||
): Promise<SpeculationResult | null> {
|
||||
if (state.status !== 'active') return null
|
||||
|
||||
const {
|
||||
id,
|
||||
messagesRef,
|
||||
writtenPathsRef,
|
||||
abort,
|
||||
startTime,
|
||||
suggestionLength,
|
||||
isPipelined,
|
||||
} = state
|
||||
const messages = messagesRef.current
|
||||
const overlayPath = getOverlayPath(id)
|
||||
const acceptedAt = Date.now()
|
||||
|
||||
abort()
|
||||
|
||||
if (cleanMessageCount > 0) {
|
||||
await copyOverlayToMain(overlayPath, writtenPathsRef.current, getCwdState())
|
||||
}
|
||||
safeRemoveOverlay(overlayPath)
|
||||
|
||||
// Use snapshot boundary as default (available since state.status === 'active' was checked above)
|
||||
let boundary: CompletionBoundary | null = state.boundary
|
||||
let timeSavedMs =
|
||||
Math.min(acceptedAt, boundary?.completedAt ?? Infinity) - startTime
|
||||
|
||||
setAppState(prev => {
|
||||
// Refine with latest React state if speculation is still active
|
||||
if (prev.speculation.status === 'active' && prev.speculation.boundary) {
|
||||
boundary = prev.speculation.boundary
|
||||
const endTime = Math.min(acceptedAt, boundary.completedAt ?? Infinity)
|
||||
timeSavedMs = endTime - startTime
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
speculation: IDLE_SPECULATION_STATE,
|
||||
speculationSessionTimeSavedMs:
|
||||
prev.speculationSessionTimeSavedMs + timeSavedMs,
|
||||
}
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
boundary === null
|
||||
? `[Speculation] Accept ${id}: still running, using ${messages.length} messages`
|
||||
: `[Speculation] Accept ${id}: already complete`,
|
||||
)
|
||||
|
||||
logSpeculation(
|
||||
id,
|
||||
'accepted',
|
||||
startTime,
|
||||
suggestionLength,
|
||||
messages,
|
||||
boundary,
|
||||
{
|
||||
message_count: messages.length,
|
||||
time_saved_ms: timeSavedMs,
|
||||
is_pipelined: isPipelined,
|
||||
},
|
||||
)
|
||||
|
||||
if (timeSavedMs > 0) {
|
||||
const entry: SpeculationAcceptMessage = {
|
||||
type: 'speculation-accept',
|
||||
timestamp: new Date().toISOString(),
|
||||
timeSavedMs,
|
||||
}
|
||||
void appendFile(getTranscriptPath(), jsonStringify(entry) + '\n', {
|
||||
mode: 0o600,
|
||||
}).catch(() => {
|
||||
logForDebugging(
|
||||
'[Speculation] Failed to write speculation-accept to transcript',
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
return { messages, boundary, timeSavedMs }
|
||||
}
|
||||
|
||||
export function abortSpeculation(setAppState: SetAppState): void {
|
||||
setAppState(prev => {
|
||||
if (prev.speculation.status !== 'active') return prev
|
||||
|
||||
const {
|
||||
id,
|
||||
abort,
|
||||
startTime,
|
||||
boundary,
|
||||
suggestionLength,
|
||||
messagesRef,
|
||||
isPipelined,
|
||||
} = prev.speculation
|
||||
|
||||
logForDebugging(`[Speculation] Aborting ${id}`)
|
||||
|
||||
logSpeculation(
|
||||
id,
|
||||
'aborted',
|
||||
startTime,
|
||||
suggestionLength,
|
||||
messagesRef.current,
|
||||
boundary,
|
||||
{ abort_reason: 'user_typed', is_pipelined: isPipelined },
|
||||
)
|
||||
|
||||
abort()
|
||||
safeRemoveOverlay(getOverlayPath(id))
|
||||
|
||||
return { ...prev, speculation: IDLE_SPECULATION_STATE }
|
||||
})
|
||||
}
|
||||
|
||||
export async function handleSpeculationAccept(
|
||||
speculationState: ActiveSpeculationState,
|
||||
speculationSessionTimeSavedMs: number,
|
||||
setAppState: SetAppState,
|
||||
input: string,
|
||||
deps: {
|
||||
setMessages: (f: (prev: Message[]) => Message[]) => void
|
||||
readFileState: { current: FileStateCache }
|
||||
cwd: string
|
||||
},
|
||||
): Promise<{ queryRequired: boolean }> {
|
||||
try {
|
||||
const { setMessages, readFileState, cwd } = deps
|
||||
|
||||
// Clear prompt suggestion state. logOutcomeAtSubmission logged the accept
|
||||
// but was called with skipReset to avoid aborting speculation before we use it.
|
||||
setAppState(prev => {
|
||||
if (
|
||||
prev.promptSuggestion.text === null &&
|
||||
prev.promptSuggestion.promptId === null
|
||||
) {
|
||||
return prev
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
promptSuggestion: {
|
||||
text: null,
|
||||
promptId: null,
|
||||
shownAt: 0,
|
||||
acceptedAt: 0,
|
||||
generationRequestId: null,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
// Capture speculation messages before any state updates - must be stable reference
|
||||
const speculationMessages = speculationState.messagesRef.current
|
||||
let cleanMessages = prepareMessagesForInjection(speculationMessages)
|
||||
|
||||
// Inject user message first for instant visual feedback before any async work
|
||||
const userMessage = createUserMessage({ content: input })
|
||||
setMessages(prev => [...prev, userMessage])
|
||||
|
||||
const result = await acceptSpeculation(
|
||||
speculationState,
|
||||
setAppState,
|
||||
cleanMessages.length,
|
||||
)
|
||||
|
||||
const isComplete = result?.boundary?.type === 'complete'
|
||||
|
||||
// When speculation didn't complete, the follow-up query needs the
|
||||
// conversation to end with a user message. Drop trailing assistant
|
||||
// messages — models that don't support prefill
|
||||
// reject conversations ending with an assistant turn. The model will
|
||||
// regenerate this content in the follow-up query.
|
||||
if (!isComplete) {
|
||||
const lastNonAssistant = cleanMessages.findLastIndex(
|
||||
m => m.type !== 'assistant',
|
||||
)
|
||||
cleanMessages = cleanMessages.slice(0, lastNonAssistant + 1)
|
||||
}
|
||||
|
||||
const timeSavedMs = result?.timeSavedMs ?? 0
|
||||
const newSessionTotal = speculationSessionTimeSavedMs + timeSavedMs
|
||||
const feedbackMessage = createSpeculationFeedbackMessage(
|
||||
cleanMessages,
|
||||
result?.boundary ?? null,
|
||||
timeSavedMs,
|
||||
newSessionTotal,
|
||||
)
|
||||
|
||||
// Inject speculated messages
|
||||
setMessages(prev => [...prev, ...cleanMessages])
|
||||
|
||||
const extracted = extractReadFilesFromMessages(
|
||||
cleanMessages,
|
||||
cwd,
|
||||
READ_FILE_STATE_CACHE_SIZE,
|
||||
)
|
||||
readFileState.current = mergeFileStateCaches(
|
||||
readFileState.current,
|
||||
extracted,
|
||||
)
|
||||
|
||||
if (feedbackMessage) {
|
||||
setMessages(prev => [...prev, feedbackMessage])
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[Speculation] ${result?.boundary?.type ?? 'incomplete'}, injected ${cleanMessages.length} messages`,
|
||||
)
|
||||
|
||||
// Promote pipelined suggestion if speculation completed fully
|
||||
if (isComplete && speculationState.pipelinedSuggestion) {
|
||||
const { text, promptId, generationRequestId } =
|
||||
speculationState.pipelinedSuggestion
|
||||
logForDebugging(
|
||||
`[Speculation] Promoting pipelined suggestion: "${text.slice(0, 50)}..."`,
|
||||
)
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
promptSuggestion: {
|
||||
text,
|
||||
promptId,
|
||||
shownAt: Date.now(),
|
||||
acceptedAt: 0,
|
||||
generationRequestId,
|
||||
},
|
||||
}))
|
||||
|
||||
// Start speculation on the pipelined suggestion
|
||||
const augmentedContext: REPLHookContext = {
|
||||
...speculationState.contextRef.current,
|
||||
messages: [
|
||||
...speculationState.contextRef.current.messages,
|
||||
createUserMessage({ content: input }),
|
||||
...cleanMessages,
|
||||
],
|
||||
}
|
||||
void startSpeculation(text, augmentedContext, setAppState, true)
|
||||
}
|
||||
|
||||
return { queryRequired: !isComplete }
|
||||
} catch (error) {
|
||||
// Fail open: log error and fall back to normal query flow
|
||||
/* eslint-disable no-restricted-syntax -- custom fallback message, not toError(e) */
|
||||
logError(
|
||||
error instanceof Error
|
||||
? error
|
||||
: new Error('handleSpeculationAccept failed'),
|
||||
)
|
||||
/* eslint-enable no-restricted-syntax */
|
||||
logSpeculation(
|
||||
speculationState.id,
|
||||
'error',
|
||||
speculationState.startTime,
|
||||
speculationState.suggestionLength,
|
||||
speculationState.messagesRef.current,
|
||||
speculationState.boundary,
|
||||
{
|
||||
error_type: error instanceof Error ? error.name : 'Unknown',
|
||||
error_message: errorMessage(error).slice(
|
||||
0,
|
||||
200,
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
error_phase:
|
||||
'accept' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
is_pipelined: speculationState.isPipelined,
|
||||
},
|
||||
)
|
||||
safeRemoveOverlay(getOverlayPath(speculationState.id))
|
||||
resetSpeculationState(setAppState)
|
||||
// Query required so user's message is processed normally (without speculated work)
|
||||
return { queryRequired: true }
|
||||
}
|
||||
}
|
||||
324
src/services/SessionMemory/prompts.ts
Normal file
324
src/services/SessionMemory/prompts.ts
Normal file
@@ -0,0 +1,324 @@
|
||||
import { readFile } from 'fs/promises'
|
||||
import { join } from 'path'
|
||||
import { roughTokenCountEstimation } from '../../services/tokenEstimation.js'
|
||||
import { getClaudeConfigHomeDir } from '../../utils/envUtils.js'
|
||||
import { getErrnoCode, toError } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
|
||||
const MAX_SECTION_LENGTH = 2000
|
||||
const MAX_TOTAL_SESSION_MEMORY_TOKENS = 12000
|
||||
|
||||
export const DEFAULT_SESSION_MEMORY_TEMPLATE = `
|
||||
# Session Title
|
||||
_A short and distinctive 5-10 word descriptive title for the session. Super info dense, no filler_
|
||||
|
||||
# Current State
|
||||
_What is actively being worked on right now? Pending tasks not yet completed. Immediate next steps._
|
||||
|
||||
# Task specification
|
||||
_What did the user ask to build? Any design decisions or other explanatory context_
|
||||
|
||||
# Files and Functions
|
||||
_What are the important files? In short, what do they contain and why are they relevant?_
|
||||
|
||||
# Workflow
|
||||
_What bash commands are usually run and in what order? How to interpret their output if not obvious?_
|
||||
|
||||
# Errors & Corrections
|
||||
_Errors encountered and how they were fixed. What did the user correct? What approaches failed and should not be tried again?_
|
||||
|
||||
# Codebase and System Documentation
|
||||
_What are the important system components? How do they work/fit together?_
|
||||
|
||||
# Learnings
|
||||
_What has worked well? What has not? What to avoid? Do not duplicate items from other sections_
|
||||
|
||||
# Key results
|
||||
_If the user asked a specific output such as an answer to a question, a table, or other document, repeat the exact result here_
|
||||
|
||||
# Worklog
|
||||
_Step by step, what was attempted, done? Very terse summary for each step_
|
||||
`
|
||||
|
||||
function getDefaultUpdatePrompt(): string {
|
||||
return `IMPORTANT: This message and these instructions are NOT part of the actual user conversation. Do NOT include any references to "note-taking", "session notes extraction", or these update instructions in the notes content.
|
||||
|
||||
Based on the user conversation above (EXCLUDING this note-taking instruction message as well as system prompt, claude.md entries, or any past session summaries), update the session notes file.
|
||||
|
||||
The file {{notesPath}} has already been read for you. Here are its current contents:
|
||||
<current_notes_content>
|
||||
{{currentNotes}}
|
||||
</current_notes_content>
|
||||
|
||||
Your ONLY task is to use the Edit tool to update the notes file, then stop. You can make multiple edits (update every section as needed) - make all Edit tool calls in parallel in a single message. Do not call any other tools.
|
||||
|
||||
CRITICAL RULES FOR EDITING:
|
||||
- The file must maintain its exact structure with all sections, headers, and italic descriptions intact
|
||||
-- NEVER modify, delete, or add section headers (the lines starting with '#' like # Task specification)
|
||||
-- NEVER modify or delete the italic _section description_ lines (these are the lines in italics immediately following each header - they start and end with underscores)
|
||||
-- The italic _section descriptions_ are TEMPLATE INSTRUCTIONS that must be preserved exactly as-is - they guide what content belongs in each section
|
||||
-- ONLY update the actual content that appears BELOW the italic _section descriptions_ within each existing section
|
||||
-- Do NOT add any new sections, summaries, or information outside the existing structure
|
||||
- Do NOT reference this note-taking process or instructions anywhere in the notes
|
||||
- It's OK to skip updating a section if there are no substantial new insights to add. Do not add filler content like "No info yet", just leave sections blank/unedited if appropriate.
|
||||
- Write DETAILED, INFO-DENSE content for each section - include specifics like file paths, function names, error messages, exact commands, technical details, etc.
|
||||
- For "Key results", include the complete, exact output the user requested (e.g., full table, full answer, etc.)
|
||||
- Do not include information that's already in the CLAUDE.md files included in the context
|
||||
- Keep each section under ~${MAX_SECTION_LENGTH} tokens/words - if a section is approaching this limit, condense it by cycling out less important details while preserving the most critical information
|
||||
- Focus on actionable, specific information that would help someone understand or recreate the work discussed in the conversation
|
||||
- IMPORTANT: Always update "Current State" to reflect the most recent work - this is critical for continuity after compaction
|
||||
|
||||
Use the Edit tool with file_path: {{notesPath}}
|
||||
|
||||
STRUCTURE PRESERVATION REMINDER:
|
||||
Each section has TWO parts that must be preserved exactly as they appear in the current file:
|
||||
1. The section header (line starting with #)
|
||||
2. The italic description line (the _italicized text_ immediately after the header - this is a template instruction)
|
||||
|
||||
You ONLY update the actual content that comes AFTER these two preserved lines. The italic description lines starting and ending with underscores are part of the template structure, NOT content to be edited or removed.
|
||||
|
||||
REMEMBER: Use the Edit tool in parallel and stop. Do not continue after the edits. Only include insights from the actual user conversation, never from these note-taking instructions. Do not delete or change section headers or italic _section descriptions_.`
|
||||
}
|
||||
|
||||
/**
|
||||
* Load custom session memory template from file if it exists
|
||||
*/
|
||||
export async function loadSessionMemoryTemplate(): Promise<string> {
|
||||
const templatePath = join(
|
||||
getClaudeConfigHomeDir(),
|
||||
'session-memory',
|
||||
'config',
|
||||
'template.md',
|
||||
)
|
||||
|
||||
try {
|
||||
return await readFile(templatePath, { encoding: 'utf-8' })
|
||||
} catch (e: unknown) {
|
||||
const code = getErrnoCode(e)
|
||||
if (code === 'ENOENT') {
|
||||
return DEFAULT_SESSION_MEMORY_TEMPLATE
|
||||
}
|
||||
logError(toError(e))
|
||||
return DEFAULT_SESSION_MEMORY_TEMPLATE
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load custom session memory prompt from file if it exists
|
||||
* Custom prompts can be placed at ~/.claude/session-memory/prompt.md
|
||||
* Use {{variableName}} syntax for variable substitution (e.g., {{currentNotes}}, {{notesPath}})
|
||||
*/
|
||||
export async function loadSessionMemoryPrompt(): Promise<string> {
|
||||
const promptPath = join(
|
||||
getClaudeConfigHomeDir(),
|
||||
'session-memory',
|
||||
'config',
|
||||
'prompt.md',
|
||||
)
|
||||
|
||||
try {
|
||||
return await readFile(promptPath, { encoding: 'utf-8' })
|
||||
} catch (e: unknown) {
|
||||
const code = getErrnoCode(e)
|
||||
if (code === 'ENOENT') {
|
||||
return getDefaultUpdatePrompt()
|
||||
}
|
||||
logError(toError(e))
|
||||
return getDefaultUpdatePrompt()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the session memory file and analyze section sizes
|
||||
*/
|
||||
function analyzeSectionSizes(content: string): Record<string, number> {
|
||||
const sections: Record<string, number> = {}
|
||||
const lines = content.split('\n')
|
||||
let currentSection = ''
|
||||
let currentContent: string[] = []
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('# ')) {
|
||||
if (currentSection && currentContent.length > 0) {
|
||||
const sectionContent = currentContent.join('\n').trim()
|
||||
sections[currentSection] = roughTokenCountEstimation(sectionContent)
|
||||
}
|
||||
currentSection = line
|
||||
currentContent = []
|
||||
} else {
|
||||
currentContent.push(line)
|
||||
}
|
||||
}
|
||||
|
||||
if (currentSection && currentContent.length > 0) {
|
||||
const sectionContent = currentContent.join('\n').trim()
|
||||
sections[currentSection] = roughTokenCountEstimation(sectionContent)
|
||||
}
|
||||
|
||||
return sections
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate reminders for sections that are too long
|
||||
*/
|
||||
function generateSectionReminders(
|
||||
sectionSizes: Record<string, number>,
|
||||
totalTokens: number,
|
||||
): string {
|
||||
const overBudget = totalTokens > MAX_TOTAL_SESSION_MEMORY_TOKENS
|
||||
const oversizedSections = Object.entries(sectionSizes)
|
||||
.filter(([_, tokens]) => tokens > MAX_SECTION_LENGTH)
|
||||
.sort(([, a], [, b]) => b - a)
|
||||
.map(
|
||||
([section, tokens]) =>
|
||||
`- "${section}" is ~${tokens} tokens (limit: ${MAX_SECTION_LENGTH})`,
|
||||
)
|
||||
|
||||
if (oversizedSections.length === 0 && !overBudget) {
|
||||
return ''
|
||||
}
|
||||
|
||||
const parts: string[] = []
|
||||
|
||||
if (overBudget) {
|
||||
parts.push(
|
||||
`\n\nCRITICAL: The session memory file is currently ~${totalTokens} tokens, which exceeds the maximum of ${MAX_TOTAL_SESSION_MEMORY_TOKENS} tokens. You MUST condense the file to fit within this budget. Aggressively shorten oversized sections by removing less important details, merging related items, and summarizing older entries. Prioritize keeping "Current State" and "Errors & Corrections" accurate and detailed.`,
|
||||
)
|
||||
}
|
||||
|
||||
if (oversizedSections.length > 0) {
|
||||
parts.push(
|
||||
`\n\n${overBudget ? 'Oversized sections to condense' : 'IMPORTANT: The following sections exceed the per-section limit and MUST be condensed'}:\n${oversizedSections.join('\n')}`,
|
||||
)
|
||||
}
|
||||
|
||||
return parts.join('')
|
||||
}
|
||||
|
||||
/**
|
||||
* Substitute variables in the prompt template using {{variable}} syntax
|
||||
*/
|
||||
function substituteVariables(
|
||||
template: string,
|
||||
variables: Record<string, string>,
|
||||
): string {
|
||||
// Single-pass replacement avoids two bugs: (1) $ backreference corruption
|
||||
// (replacer fn treats $ literally), and (2) double-substitution when user
|
||||
// content happens to contain {{varName}} matching a later variable.
|
||||
return template.replace(/\{\{(\w+)\}\}/g, (match, key: string) =>
|
||||
Object.prototype.hasOwnProperty.call(variables, key)
|
||||
? variables[key]!
|
||||
: match,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the session memory content is essentially empty (matches the template).
|
||||
* This is used to detect if no actual content has been extracted yet,
|
||||
* which means we should fall back to legacy compact behavior.
|
||||
*/
|
||||
export async function isSessionMemoryEmpty(content: string): Promise<boolean> {
|
||||
const template = await loadSessionMemoryTemplate()
|
||||
// Compare trimmed content to detect if it's just the template
|
||||
return content.trim() === template.trim()
|
||||
}
|
||||
|
||||
export async function buildSessionMemoryUpdatePrompt(
|
||||
currentNotes: string,
|
||||
notesPath: string,
|
||||
): Promise<string> {
|
||||
const promptTemplate = await loadSessionMemoryPrompt()
|
||||
|
||||
// Analyze section sizes and generate reminders if needed
|
||||
const sectionSizes = analyzeSectionSizes(currentNotes)
|
||||
const totalTokens = roughTokenCountEstimation(currentNotes)
|
||||
const sectionReminders = generateSectionReminders(sectionSizes, totalTokens)
|
||||
|
||||
// Substitute variables in the prompt
|
||||
const variables = {
|
||||
currentNotes,
|
||||
notesPath,
|
||||
}
|
||||
|
||||
const basePrompt = substituteVariables(promptTemplate, variables)
|
||||
|
||||
// Add section size reminders and/or total budget warnings
|
||||
return basePrompt + sectionReminders
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate session memory sections that exceed the per-section token limit.
|
||||
* Used when inserting session memory into compact messages to prevent
|
||||
* oversized session memory from consuming the entire post-compact token budget.
|
||||
*
|
||||
* Returns the truncated content and whether any truncation occurred.
|
||||
*/
|
||||
export function truncateSessionMemoryForCompact(content: string): {
|
||||
truncatedContent: string
|
||||
wasTruncated: boolean
|
||||
} {
|
||||
const lines = content.split('\n')
|
||||
const maxCharsPerSection = MAX_SECTION_LENGTH * 4 // roughTokenCountEstimation uses length/4
|
||||
const outputLines: string[] = []
|
||||
let currentSectionLines: string[] = []
|
||||
let currentSectionHeader = ''
|
||||
let wasTruncated = false
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('# ')) {
|
||||
const result = flushSessionSection(
|
||||
currentSectionHeader,
|
||||
currentSectionLines,
|
||||
maxCharsPerSection,
|
||||
)
|
||||
outputLines.push(...result.lines)
|
||||
wasTruncated = wasTruncated || result.wasTruncated
|
||||
currentSectionHeader = line
|
||||
currentSectionLines = []
|
||||
} else {
|
||||
currentSectionLines.push(line)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the last section
|
||||
const result = flushSessionSection(
|
||||
currentSectionHeader,
|
||||
currentSectionLines,
|
||||
maxCharsPerSection,
|
||||
)
|
||||
outputLines.push(...result.lines)
|
||||
wasTruncated = wasTruncated || result.wasTruncated
|
||||
|
||||
return {
|
||||
truncatedContent: outputLines.join('\n'),
|
||||
wasTruncated,
|
||||
}
|
||||
}
|
||||
|
||||
function flushSessionSection(
|
||||
sectionHeader: string,
|
||||
sectionLines: string[],
|
||||
maxCharsPerSection: number,
|
||||
): { lines: string[]; wasTruncated: boolean } {
|
||||
if (!sectionHeader) {
|
||||
return { lines: sectionLines, wasTruncated: false }
|
||||
}
|
||||
|
||||
const sectionContent = sectionLines.join('\n')
|
||||
if (sectionContent.length <= maxCharsPerSection) {
|
||||
return { lines: [sectionHeader, ...sectionLines], wasTruncated: false }
|
||||
}
|
||||
|
||||
// Truncate at a line boundary near the limit
|
||||
let charCount = 0
|
||||
const keptLines: string[] = [sectionHeader]
|
||||
for (const line of sectionLines) {
|
||||
if (charCount + line.length + 1 > maxCharsPerSection) {
|
||||
break
|
||||
}
|
||||
keptLines.push(line)
|
||||
charCount += line.length + 1
|
||||
}
|
||||
keptLines.push('\n[... section truncated for length ...]')
|
||||
return { lines: keptLines, wasTruncated: true }
|
||||
}
|
||||
495
src/services/SessionMemory/sessionMemory.ts
Normal file
495
src/services/SessionMemory/sessionMemory.ts
Normal file
@@ -0,0 +1,495 @@
|
||||
/**
|
||||
* Session Memory automatically maintains a markdown file with notes about the current conversation.
|
||||
* It runs periodically in the background using a forked subagent to extract key information
|
||||
* without interrupting the main conversation flow.
|
||||
*/
|
||||
|
||||
import { writeFile } from 'fs/promises'
|
||||
import memoize from 'lodash-es/memoize.js'
|
||||
import { getIsRemoteMode } from '../../bootstrap/state.js'
|
||||
import { getSystemPrompt } from '../../constants/prompts.js'
|
||||
import { getSystemContext, getUserContext } from '../../context.js'
|
||||
import type { CanUseToolFn } from '../../hooks/useCanUseTool.js'
|
||||
import type { Tool, ToolUseContext } from '../../Tool.js'
|
||||
import { FILE_EDIT_TOOL_NAME } from '../../tools/FileEditTool/constants.js'
|
||||
import {
|
||||
FileReadTool,
|
||||
type Output as FileReadToolOutput,
|
||||
} from '../../tools/FileReadTool/FileReadTool.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { count } from '../../utils/array.js'
|
||||
import {
|
||||
createCacheSafeParams,
|
||||
createSubagentContext,
|
||||
runForkedAgent,
|
||||
} from '../../utils/forkedAgent.js'
|
||||
import { getFsImplementation } from '../../utils/fsOperations.js'
|
||||
import {
|
||||
type REPLHookContext,
|
||||
registerPostSamplingHook,
|
||||
} from '../../utils/hooks/postSamplingHooks.js'
|
||||
import {
|
||||
createUserMessage,
|
||||
hasToolCallsInLastAssistantTurn,
|
||||
} from '../../utils/messages.js'
|
||||
import {
|
||||
getSessionMemoryDir,
|
||||
getSessionMemoryPath,
|
||||
} from '../../utils/permissions/filesystem.js'
|
||||
import { sequential } from '../../utils/sequential.js'
|
||||
import { asSystemPrompt } from '../../utils/systemPromptType.js'
|
||||
import { getTokenUsage, tokenCountWithEstimation } from '../../utils/tokens.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
import { isAutoCompactEnabled } from '../compact/autoCompact.js'
|
||||
import {
|
||||
buildSessionMemoryUpdatePrompt,
|
||||
loadSessionMemoryTemplate,
|
||||
} from './prompts.js'
|
||||
import {
|
||||
DEFAULT_SESSION_MEMORY_CONFIG,
|
||||
getSessionMemoryConfig,
|
||||
getToolCallsBetweenUpdates,
|
||||
hasMetInitializationThreshold,
|
||||
hasMetUpdateThreshold,
|
||||
isSessionMemoryInitialized,
|
||||
markExtractionCompleted,
|
||||
markExtractionStarted,
|
||||
markSessionMemoryInitialized,
|
||||
recordExtractionTokenCount,
|
||||
type SessionMemoryConfig,
|
||||
setLastSummarizedMessageId,
|
||||
setSessionMemoryConfig,
|
||||
} from './sessionMemoryUtils.js'
|
||||
|
||||
// ============================================================================
|
||||
// Feature Gate and Config (Cached - Non-blocking)
|
||||
// ============================================================================
|
||||
// These functions return cached values from disk immediately without blocking
|
||||
// on GrowthBook initialization. Values may be stale but are updated in background.
|
||||
|
||||
import { errorMessage, getErrnoCode } from '../../utils/errors.js'
|
||||
import {
|
||||
getDynamicConfig_CACHED_MAY_BE_STALE,
|
||||
getFeatureValue_CACHED_MAY_BE_STALE,
|
||||
} from '../analytics/growthbook.js'
|
||||
|
||||
/**
|
||||
* Check if session memory feature is enabled.
|
||||
* Uses cached gate value - returns immediately without blocking.
|
||||
*/
|
||||
function isSessionMemoryGateEnabled(): boolean {
|
||||
return getFeatureValue_CACHED_MAY_BE_STALE('tengu_session_memory', false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get session memory config from cache.
|
||||
* Returns immediately without blocking - value may be stale.
|
||||
*/
|
||||
function getSessionMemoryRemoteConfig(): Partial<SessionMemoryConfig> {
|
||||
return getDynamicConfig_CACHED_MAY_BE_STALE<Partial<SessionMemoryConfig>>(
|
||||
'tengu_sm_config',
|
||||
{},
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Module State
|
||||
// ============================================================================
|
||||
|
||||
let lastMemoryMessageUuid: string | undefined
|
||||
|
||||
/**
|
||||
* Reset the last memory message UUID (for testing)
|
||||
*/
|
||||
export function resetLastMemoryMessageUuid(): void {
|
||||
lastMemoryMessageUuid = undefined
|
||||
}
|
||||
|
||||
function countToolCallsSince(
|
||||
messages: Message[],
|
||||
sinceUuid: string | undefined,
|
||||
): number {
|
||||
let toolCallCount = 0
|
||||
let foundStart = sinceUuid === null || sinceUuid === undefined
|
||||
|
||||
for (const message of messages) {
|
||||
if (!foundStart) {
|
||||
if (message.uuid === sinceUuid) {
|
||||
foundStart = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if (message.type === 'assistant') {
|
||||
const content = message.message.content
|
||||
if (Array.isArray(content)) {
|
||||
toolCallCount += count(content, block => block.type === 'tool_use')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return toolCallCount
|
||||
}
|
||||
|
||||
export function shouldExtractMemory(messages: Message[]): boolean {
|
||||
// Check if we've met the initialization threshold
|
||||
// Uses total context window tokens (same as autocompact) for consistent behavior
|
||||
const currentTokenCount = tokenCountWithEstimation(messages)
|
||||
if (!isSessionMemoryInitialized()) {
|
||||
if (!hasMetInitializationThreshold(currentTokenCount)) {
|
||||
return false
|
||||
}
|
||||
markSessionMemoryInitialized()
|
||||
}
|
||||
|
||||
// Check if we've met the minimum tokens between updates threshold
|
||||
// Uses context window growth since last extraction (same metric as init threshold)
|
||||
const hasMetTokenThreshold = hasMetUpdateThreshold(currentTokenCount)
|
||||
|
||||
// Check if we've met the tool calls threshold
|
||||
const toolCallsSinceLastUpdate = countToolCallsSince(
|
||||
messages,
|
||||
lastMemoryMessageUuid,
|
||||
)
|
||||
const hasMetToolCallThreshold =
|
||||
toolCallsSinceLastUpdate >= getToolCallsBetweenUpdates()
|
||||
|
||||
// Check if the last assistant turn has no tool calls (safe to extract)
|
||||
const hasToolCallsInLastTurn = hasToolCallsInLastAssistantTurn(messages)
|
||||
|
||||
// Trigger extraction when:
|
||||
// 1. Both thresholds are met (tokens AND tool calls), OR
|
||||
// 2. No tool calls in last turn AND token threshold is met
|
||||
// (to ensure we extract at natural conversation breaks)
|
||||
//
|
||||
// IMPORTANT: The token threshold (minimumTokensBetweenUpdate) is ALWAYS required.
|
||||
// Even if the tool call threshold is met, extraction won't happen until the
|
||||
// token threshold is also satisfied. This prevents excessive extractions.
|
||||
const shouldExtract =
|
||||
(hasMetTokenThreshold && hasMetToolCallThreshold) ||
|
||||
(hasMetTokenThreshold && !hasToolCallsInLastTurn)
|
||||
|
||||
if (shouldExtract) {
|
||||
const lastMessage = messages[messages.length - 1]
|
||||
if (lastMessage?.uuid) {
|
||||
lastMemoryMessageUuid = lastMessage.uuid
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
async function setupSessionMemoryFile(
|
||||
toolUseContext: ToolUseContext,
|
||||
): Promise<{ memoryPath: string; currentMemory: string }> {
|
||||
const fs = getFsImplementation()
|
||||
|
||||
// Set up directory and file
|
||||
const sessionMemoryDir = getSessionMemoryDir()
|
||||
await fs.mkdir(sessionMemoryDir, { mode: 0o700 })
|
||||
|
||||
const memoryPath = getSessionMemoryPath()
|
||||
|
||||
// Create the memory file if it doesn't exist (wx = O_CREAT|O_EXCL)
|
||||
try {
|
||||
await writeFile(memoryPath, '', {
|
||||
encoding: 'utf-8',
|
||||
mode: 0o600,
|
||||
flag: 'wx',
|
||||
})
|
||||
// Only load template if file was just created
|
||||
const template = await loadSessionMemoryTemplate()
|
||||
await writeFile(memoryPath, template, {
|
||||
encoding: 'utf-8',
|
||||
mode: 0o600,
|
||||
})
|
||||
} catch (e: unknown) {
|
||||
const code = getErrnoCode(e)
|
||||
if (code !== 'EEXIST') {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
// Drop any cached entry so FileReadTool's dedup doesn't return a
|
||||
// file_unchanged stub — we need the actual content. The Read repopulates it.
|
||||
toolUseContext.readFileState.delete(memoryPath)
|
||||
const result = await FileReadTool.call(
|
||||
{ file_path: memoryPath },
|
||||
toolUseContext,
|
||||
)
|
||||
let currentMemory = ''
|
||||
|
||||
const output = result.data as FileReadToolOutput
|
||||
if (output.type === 'text') {
|
||||
currentMemory = output.file.content
|
||||
}
|
||||
|
||||
logEvent('tengu_session_memory_file_read', {
|
||||
content_length: currentMemory.length,
|
||||
})
|
||||
|
||||
return { memoryPath, currentMemory }
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize session memory config from remote config (lazy initialization).
|
||||
* Memoized - only runs once per session, subsequent calls return immediately.
|
||||
* Uses cached config values - non-blocking.
|
||||
*/
|
||||
const initSessionMemoryConfigIfNeeded = memoize((): void => {
|
||||
// Load config from cache (non-blocking, may be stale)
|
||||
const remoteConfig = getSessionMemoryRemoteConfig()
|
||||
|
||||
// Only use remote values if they are explicitly set (non-zero positive numbers)
|
||||
// This ensures sensible defaults aren't overridden by zero values
|
||||
const config: SessionMemoryConfig = {
|
||||
minimumMessageTokensToInit:
|
||||
remoteConfig.minimumMessageTokensToInit &&
|
||||
remoteConfig.minimumMessageTokensToInit > 0
|
||||
? remoteConfig.minimumMessageTokensToInit
|
||||
: DEFAULT_SESSION_MEMORY_CONFIG.minimumMessageTokensToInit,
|
||||
minimumTokensBetweenUpdate:
|
||||
remoteConfig.minimumTokensBetweenUpdate &&
|
||||
remoteConfig.minimumTokensBetweenUpdate > 0
|
||||
? remoteConfig.minimumTokensBetweenUpdate
|
||||
: DEFAULT_SESSION_MEMORY_CONFIG.minimumTokensBetweenUpdate,
|
||||
toolCallsBetweenUpdates:
|
||||
remoteConfig.toolCallsBetweenUpdates &&
|
||||
remoteConfig.toolCallsBetweenUpdates > 0
|
||||
? remoteConfig.toolCallsBetweenUpdates
|
||||
: DEFAULT_SESSION_MEMORY_CONFIG.toolCallsBetweenUpdates,
|
||||
}
|
||||
setSessionMemoryConfig(config)
|
||||
})
|
||||
|
||||
/**
|
||||
* Session memory post-sampling hook that extracts and updates session notes
|
||||
*/
|
||||
// Track if we've logged the gate check failure this session (to avoid spam)
|
||||
let hasLoggedGateFailure = false
|
||||
|
||||
const extractSessionMemory = sequential(async function (
|
||||
context: REPLHookContext,
|
||||
): Promise<void> {
|
||||
const { messages, toolUseContext, querySource } = context
|
||||
|
||||
// Only run session memory on main REPL thread
|
||||
if (querySource !== 'repl_main_thread') {
|
||||
// Don't log this - it's expected for subagents, teammates, etc.
|
||||
return
|
||||
}
|
||||
|
||||
// Check gate lazily when hook runs (cached, non-blocking)
|
||||
if (!isSessionMemoryGateEnabled()) {
|
||||
// Log gate failure once per session (ant-only)
|
||||
if (process.env.USER_TYPE === 'ant' && !hasLoggedGateFailure) {
|
||||
hasLoggedGateFailure = true
|
||||
logEvent('tengu_session_memory_gate_disabled', {})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize config from remote (lazy, only once)
|
||||
initSessionMemoryConfigIfNeeded()
|
||||
|
||||
if (!shouldExtractMemory(messages)) {
|
||||
return
|
||||
}
|
||||
|
||||
markExtractionStarted()
|
||||
|
||||
// Create isolated context for setup to avoid polluting parent's cache
|
||||
const setupContext = createSubagentContext(toolUseContext)
|
||||
|
||||
// Set up file system and read current state with isolated context
|
||||
const { memoryPath, currentMemory } =
|
||||
await setupSessionMemoryFile(setupContext)
|
||||
|
||||
// Create extraction message
|
||||
const userPrompt = await buildSessionMemoryUpdatePrompt(
|
||||
currentMemory,
|
||||
memoryPath,
|
||||
)
|
||||
|
||||
// Run session memory extraction using runForkedAgent for prompt caching
|
||||
// runForkedAgent creates an isolated context to prevent mutation of parent state
|
||||
// Pass setupContext.readFileState so the forked agent can edit the memory file
|
||||
await runForkedAgent({
|
||||
promptMessages: [createUserMessage({ content: userPrompt })],
|
||||
cacheSafeParams: createCacheSafeParams(context),
|
||||
canUseTool: createMemoryFileCanUseTool(memoryPath),
|
||||
querySource: 'session_memory',
|
||||
forkLabel: 'session_memory',
|
||||
overrides: { readFileState: setupContext.readFileState },
|
||||
})
|
||||
|
||||
// Log extraction event for tracking frequency
|
||||
// Use the token usage from the last message in the conversation
|
||||
const lastMessage = messages[messages.length - 1]
|
||||
const usage = lastMessage ? getTokenUsage(lastMessage) : undefined
|
||||
const config = getSessionMemoryConfig()
|
||||
logEvent('tengu_session_memory_extraction', {
|
||||
input_tokens: usage?.input_tokens,
|
||||
output_tokens: usage?.output_tokens,
|
||||
cache_read_input_tokens: usage?.cache_read_input_tokens ?? undefined,
|
||||
cache_creation_input_tokens:
|
||||
usage?.cache_creation_input_tokens ?? undefined,
|
||||
config_min_message_tokens_to_init: config.minimumMessageTokensToInit,
|
||||
config_min_tokens_between_update: config.minimumTokensBetweenUpdate,
|
||||
config_tool_calls_between_updates: config.toolCallsBetweenUpdates,
|
||||
})
|
||||
|
||||
// Record the context size at extraction for tracking minimumTokensBetweenUpdate
|
||||
recordExtractionTokenCount(tokenCountWithEstimation(messages))
|
||||
|
||||
// Update lastSummarizedMessageId after successful completion
|
||||
updateLastSummarizedMessageIdIfSafe(messages)
|
||||
|
||||
markExtractionCompleted()
|
||||
})
|
||||
|
||||
/**
|
||||
* Initialize session memory by registering the post-sampling hook.
|
||||
* This is synchronous to avoid race conditions during startup.
|
||||
* The gate check and config loading happen lazily when the hook runs.
|
||||
*/
|
||||
export function initSessionMemory(): void {
|
||||
if (getIsRemoteMode()) return
|
||||
// Session memory is used for compaction, so respect auto-compact settings
|
||||
const autoCompactEnabled = isAutoCompactEnabled()
|
||||
|
||||
// Log initialization state (ant-only to avoid noise in external logs)
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logEvent('tengu_session_memory_init', {
|
||||
auto_compact_enabled: autoCompactEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
if (!autoCompactEnabled) {
|
||||
return
|
||||
}
|
||||
|
||||
// Register hook unconditionally - gate check happens lazily when hook runs
|
||||
registerPostSamplingHook(extractSessionMemory)
|
||||
}
|
||||
|
||||
export type ManualExtractionResult = {
|
||||
success: boolean
|
||||
memoryPath?: string
|
||||
error?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually trigger session memory extraction, bypassing threshold checks.
|
||||
* Used by the /summary command.
|
||||
*/
|
||||
export async function manuallyExtractSessionMemory(
|
||||
messages: Message[],
|
||||
toolUseContext: ToolUseContext,
|
||||
): Promise<ManualExtractionResult> {
|
||||
if (messages.length === 0) {
|
||||
return { success: false, error: 'No messages to summarize' }
|
||||
}
|
||||
markExtractionStarted()
|
||||
|
||||
try {
|
||||
// Create isolated context for setup to avoid polluting parent's cache
|
||||
const setupContext = createSubagentContext(toolUseContext)
|
||||
|
||||
// Set up file system and read current state with isolated context
|
||||
const { memoryPath, currentMemory } =
|
||||
await setupSessionMemoryFile(setupContext)
|
||||
|
||||
// Create extraction message
|
||||
const userPrompt = await buildSessionMemoryUpdatePrompt(
|
||||
currentMemory,
|
||||
memoryPath,
|
||||
)
|
||||
|
||||
// Get system prompt for cache-safe params
|
||||
const { tools, mainLoopModel } = toolUseContext.options
|
||||
const [rawSystemPrompt, userContext, systemContext] = await Promise.all([
|
||||
getSystemPrompt(tools, mainLoopModel),
|
||||
getUserContext(),
|
||||
getSystemContext(),
|
||||
])
|
||||
const systemPrompt = asSystemPrompt(rawSystemPrompt)
|
||||
|
||||
// Run session memory extraction using runForkedAgent
|
||||
await runForkedAgent({
|
||||
promptMessages: [createUserMessage({ content: userPrompt })],
|
||||
cacheSafeParams: {
|
||||
systemPrompt,
|
||||
userContext,
|
||||
systemContext,
|
||||
toolUseContext: setupContext,
|
||||
forkContextMessages: messages,
|
||||
},
|
||||
canUseTool: createMemoryFileCanUseTool(memoryPath),
|
||||
querySource: 'session_memory',
|
||||
forkLabel: 'session_memory_manual',
|
||||
overrides: { readFileState: setupContext.readFileState },
|
||||
})
|
||||
|
||||
// Log manual extraction event
|
||||
logEvent('tengu_session_memory_manual_extraction', {})
|
||||
|
||||
// Record the context size at extraction for tracking minimumTokensBetweenUpdate
|
||||
recordExtractionTokenCount(tokenCountWithEstimation(messages))
|
||||
|
||||
// Update lastSummarizedMessageId after successful completion
|
||||
updateLastSummarizedMessageIdIfSafe(messages)
|
||||
|
||||
return { success: true, memoryPath }
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: errorMessage(error),
|
||||
}
|
||||
} finally {
|
||||
markExtractionCompleted()
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
/**
|
||||
* Creates a canUseTool function that only allows Edit for the exact memory file.
|
||||
*/
|
||||
export function createMemoryFileCanUseTool(memoryPath: string): CanUseToolFn {
|
||||
return async (tool: Tool, input: unknown) => {
|
||||
if (
|
||||
tool.name === FILE_EDIT_TOOL_NAME &&
|
||||
typeof input === 'object' &&
|
||||
input !== null &&
|
||||
'file_path' in input
|
||||
) {
|
||||
const filePath = input.file_path
|
||||
if (typeof filePath === 'string' && filePath === memoryPath) {
|
||||
return { behavior: 'allow' as const, updatedInput: input }
|
||||
}
|
||||
}
|
||||
return {
|
||||
behavior: 'deny' as const,
|
||||
message: `only ${FILE_EDIT_TOOL_NAME} on ${memoryPath} is allowed`,
|
||||
decisionReason: {
|
||||
type: 'other' as const,
|
||||
reason: `only ${FILE_EDIT_TOOL_NAME} on ${memoryPath} is allowed`,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates lastSummarizedMessageId after successful extraction.
|
||||
* Only sets it if the last message doesn't have tool calls (to avoid orphaned tool_results).
|
||||
*/
|
||||
function updateLastSummarizedMessageIdIfSafe(messages: Message[]): void {
|
||||
if (!hasToolCallsInLastAssistantTurn(messages)) {
|
||||
const lastMessage = messages[messages.length - 1]
|
||||
if (lastMessage?.uuid) {
|
||||
setLastSummarizedMessageId(lastMessage.uuid)
|
||||
}
|
||||
}
|
||||
}
|
||||
207
src/services/SessionMemory/sessionMemoryUtils.ts
Normal file
207
src/services/SessionMemory/sessionMemoryUtils.ts
Normal file
@@ -0,0 +1,207 @@
|
||||
/**
|
||||
* Session Memory utility functions that can be imported without circular dependencies.
|
||||
* These are separate from the main sessionMemory.ts to avoid importing runAgent.
|
||||
*/
|
||||
|
||||
import { isFsInaccessible } from '../../utils/errors.js'
|
||||
import { getFsImplementation } from '../../utils/fsOperations.js'
|
||||
import { getSessionMemoryPath } from '../../utils/permissions/filesystem.js'
|
||||
import { sleep } from '../../utils/sleep.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
|
||||
const EXTRACTION_WAIT_TIMEOUT_MS = 15000
|
||||
const EXTRACTION_STALE_THRESHOLD_MS = 60000 // 1 minute
|
||||
|
||||
/**
|
||||
* Configuration for session memory extraction thresholds
|
||||
*/
|
||||
export type SessionMemoryConfig = {
|
||||
/** Minimum context window tokens before initializing session memory.
|
||||
* Uses the same token counting as autocompact (input + output + cache tokens)
|
||||
* to ensure consistent behavior between the two features. */
|
||||
minimumMessageTokensToInit: number
|
||||
/** Minimum context window growth (in tokens) between session memory updates.
|
||||
* Uses the same token counting as autocompact (tokenCountWithEstimation)
|
||||
* to measure actual context growth, not cumulative API usage. */
|
||||
minimumTokensBetweenUpdate: number
|
||||
/** Number of tool calls between session memory updates */
|
||||
toolCallsBetweenUpdates: number
|
||||
}
|
||||
|
||||
// Default configuration values
|
||||
export const DEFAULT_SESSION_MEMORY_CONFIG: SessionMemoryConfig = {
|
||||
minimumMessageTokensToInit: 10000,
|
||||
minimumTokensBetweenUpdate: 5000,
|
||||
toolCallsBetweenUpdates: 3,
|
||||
}
|
||||
|
||||
// Current session memory configuration
|
||||
let sessionMemoryConfig: SessionMemoryConfig = {
|
||||
...DEFAULT_SESSION_MEMORY_CONFIG,
|
||||
}
|
||||
|
||||
// Track the last summarized message ID (shared state)
|
||||
let lastSummarizedMessageId: string | undefined
|
||||
|
||||
// Track extraction state with timestamp (set by sessionMemory.ts)
|
||||
let extractionStartedAt: number | undefined
|
||||
|
||||
// Track context size at last memory extraction (for minimumTokensBetweenUpdate)
|
||||
let tokensAtLastExtraction = 0
|
||||
|
||||
// Track whether session memory has been initialized (met minimumMessageTokensToInit)
|
||||
let sessionMemoryInitialized = false
|
||||
|
||||
/**
|
||||
* Get the message ID up to which the session memory is current
|
||||
*/
|
||||
export function getLastSummarizedMessageId(): string | undefined {
|
||||
return lastSummarizedMessageId
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the last summarized message ID (called from sessionMemory.ts)
|
||||
*/
|
||||
export function setLastSummarizedMessageId(
|
||||
messageId: string | undefined,
|
||||
): void {
|
||||
lastSummarizedMessageId = messageId
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark extraction as started (called from sessionMemory.ts)
|
||||
*/
|
||||
export function markExtractionStarted(): void {
|
||||
extractionStartedAt = Date.now()
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark extraction as completed (called from sessionMemory.ts)
|
||||
*/
|
||||
export function markExtractionCompleted(): void {
|
||||
extractionStartedAt = undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for any in-progress session memory extraction to complete (with 15s timeout)
|
||||
* Returns immediately if no extraction is in progress or if extraction is stale (>1min old).
|
||||
*/
|
||||
export async function waitForSessionMemoryExtraction(): Promise<void> {
|
||||
const startTime = Date.now()
|
||||
while (extractionStartedAt) {
|
||||
const extractionAge = Date.now() - extractionStartedAt
|
||||
if (extractionAge > EXTRACTION_STALE_THRESHOLD_MS) {
|
||||
// Extraction is stale, don't wait
|
||||
return
|
||||
}
|
||||
|
||||
if (Date.now() - startTime > EXTRACTION_WAIT_TIMEOUT_MS) {
|
||||
// Timeout - continue anyway
|
||||
return
|
||||
}
|
||||
|
||||
await sleep(1000)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current session memory content
|
||||
*/
|
||||
export async function getSessionMemoryContent(): Promise<string | null> {
|
||||
const fs = getFsImplementation()
|
||||
const memoryPath = getSessionMemoryPath()
|
||||
|
||||
try {
|
||||
const content = await fs.readFile(memoryPath, { encoding: 'utf-8' })
|
||||
|
||||
logEvent('tengu_session_memory_loaded', {
|
||||
content_length: content.length,
|
||||
})
|
||||
|
||||
return content
|
||||
} catch (e: unknown) {
|
||||
if (isFsInaccessible(e)) return null
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the session memory configuration
|
||||
*/
|
||||
export function setSessionMemoryConfig(
|
||||
config: Partial<SessionMemoryConfig>,
|
||||
): void {
|
||||
sessionMemoryConfig = {
|
||||
...sessionMemoryConfig,
|
||||
...config,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current session memory configuration
|
||||
*/
|
||||
export function getSessionMemoryConfig(): SessionMemoryConfig {
|
||||
return { ...sessionMemoryConfig }
|
||||
}
|
||||
|
||||
/**
|
||||
* Record the context size at the time of extraction.
|
||||
* Used to measure context growth for minimumTokensBetweenUpdate threshold.
|
||||
*/
|
||||
export function recordExtractionTokenCount(currentTokenCount: number): void {
|
||||
tokensAtLastExtraction = currentTokenCount
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if session memory has been initialized (met minimumTokensToInit threshold)
|
||||
*/
|
||||
export function isSessionMemoryInitialized(): boolean {
|
||||
return sessionMemoryInitialized
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark session memory as initialized
|
||||
*/
|
||||
export function markSessionMemoryInitialized(): void {
|
||||
sessionMemoryInitialized = true
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if we've met the threshold to initialize session memory.
|
||||
* Uses total context window tokens (same as autocompact) for consistent behavior.
|
||||
*/
|
||||
export function hasMetInitializationThreshold(
|
||||
currentTokenCount: number,
|
||||
): boolean {
|
||||
return currentTokenCount >= sessionMemoryConfig.minimumMessageTokensToInit
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if we've met the threshold for the next update.
|
||||
* Measures actual context window growth since last extraction
|
||||
* (same metric as autocompact and initialization threshold).
|
||||
*/
|
||||
export function hasMetUpdateThreshold(currentTokenCount: number): boolean {
|
||||
const tokensSinceLastExtraction = currentTokenCount - tokensAtLastExtraction
|
||||
return (
|
||||
tokensSinceLastExtraction >= sessionMemoryConfig.minimumTokensBetweenUpdate
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the configured number of tool calls between updates
|
||||
*/
|
||||
export function getToolCallsBetweenUpdates(): number {
|
||||
return sessionMemoryConfig.toolCallsBetweenUpdates
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset session memory state (useful for testing)
|
||||
*/
|
||||
export function resetSessionMemoryState(): void {
|
||||
sessionMemoryConfig = { ...DEFAULT_SESSION_MEMORY_CONFIG }
|
||||
tokensAtLastExtraction = 0
|
||||
sessionMemoryInitialized = false
|
||||
lastSummarizedMessageId = undefined
|
||||
extractionStartedAt = undefined
|
||||
}
|
||||
38
src/services/analytics/config.ts
Normal file
38
src/services/analytics/config.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Shared analytics configuration
|
||||
*
|
||||
* Common logic for determining when analytics should be disabled
|
||||
* across all analytics systems (Datadog, 1P)
|
||||
*/
|
||||
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { isTelemetryDisabled } from '../../utils/privacyLevel.js'
|
||||
|
||||
/**
|
||||
* Check if analytics operations should be disabled
|
||||
*
|
||||
* Analytics is disabled in the following cases:
|
||||
* - Test environment (NODE_ENV === 'test')
|
||||
* - Third-party cloud providers (Bedrock/Vertex)
|
||||
* - Privacy level is no-telemetry or essential-traffic
|
||||
*/
|
||||
export function isAnalyticsDisabled(): boolean {
|
||||
return (
|
||||
process.env.NODE_ENV === 'test' ||
|
||||
isEnvTruthy(process.env.CLAUDE_CODE_USE_BEDROCK) ||
|
||||
isEnvTruthy(process.env.CLAUDE_CODE_USE_VERTEX) ||
|
||||
isEnvTruthy(process.env.CLAUDE_CODE_USE_FOUNDRY) ||
|
||||
isTelemetryDisabled()
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the feedback survey should be suppressed.
|
||||
*
|
||||
* Unlike isAnalyticsDisabled(), this does NOT block on 3P providers
|
||||
* (Bedrock/Vertex/Foundry). The survey is a local UI prompt with no
|
||||
* transcript data — enterprise customers capture responses via OTEL.
|
||||
*/
|
||||
export function isFeedbackSurveyDisabled(): boolean {
|
||||
return process.env.NODE_ENV === 'test' || isTelemetryDisabled()
|
||||
}
|
||||
307
src/services/analytics/datadog.ts
Normal file
307
src/services/analytics/datadog.ts
Normal file
@@ -0,0 +1,307 @@
|
||||
import axios from 'axios'
|
||||
import { createHash } from 'crypto'
|
||||
import memoize from 'lodash-es/memoize.js'
|
||||
import { getOrCreateUserID } from '../../utils/config.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getCanonicalName } from '../../utils/model/model.js'
|
||||
import { getAPIProvider } from '../../utils/model/providers.js'
|
||||
import { MODEL_COSTS } from '../../utils/modelCost.js'
|
||||
import { isAnalyticsDisabled } from './config.js'
|
||||
import { getEventMetadata } from './metadata.js'
|
||||
|
||||
const DATADOG_LOGS_ENDPOINT =
|
||||
'https://http-intake.logs.us5.datadoghq.com/api/v2/logs'
|
||||
const DATADOG_CLIENT_TOKEN = 'pubbbf48e6d78dae54bceaa4acf463299bf'
|
||||
const DEFAULT_FLUSH_INTERVAL_MS = 15000
|
||||
const MAX_BATCH_SIZE = 100
|
||||
const NETWORK_TIMEOUT_MS = 5000
|
||||
|
||||
const DATADOG_ALLOWED_EVENTS = new Set([
|
||||
'chrome_bridge_connection_succeeded',
|
||||
'chrome_bridge_connection_failed',
|
||||
'chrome_bridge_disconnected',
|
||||
'chrome_bridge_tool_call_completed',
|
||||
'chrome_bridge_tool_call_error',
|
||||
'chrome_bridge_tool_call_started',
|
||||
'chrome_bridge_tool_call_timeout',
|
||||
'tengu_api_error',
|
||||
'tengu_api_success',
|
||||
'tengu_brief_mode_enabled',
|
||||
'tengu_brief_mode_toggled',
|
||||
'tengu_brief_send',
|
||||
'tengu_cancel',
|
||||
'tengu_compact_failed',
|
||||
'tengu_exit',
|
||||
'tengu_flicker',
|
||||
'tengu_init',
|
||||
'tengu_model_fallback_triggered',
|
||||
'tengu_oauth_error',
|
||||
'tengu_oauth_success',
|
||||
'tengu_oauth_token_refresh_failure',
|
||||
'tengu_oauth_token_refresh_success',
|
||||
'tengu_oauth_token_refresh_lock_acquiring',
|
||||
'tengu_oauth_token_refresh_lock_acquired',
|
||||
'tengu_oauth_token_refresh_starting',
|
||||
'tengu_oauth_token_refresh_completed',
|
||||
'tengu_oauth_token_refresh_lock_releasing',
|
||||
'tengu_oauth_token_refresh_lock_released',
|
||||
'tengu_query_error',
|
||||
'tengu_session_file_read',
|
||||
'tengu_started',
|
||||
'tengu_tool_use_error',
|
||||
'tengu_tool_use_granted_in_prompt_permanent',
|
||||
'tengu_tool_use_granted_in_prompt_temporary',
|
||||
'tengu_tool_use_rejected_in_prompt',
|
||||
'tengu_tool_use_success',
|
||||
'tengu_uncaught_exception',
|
||||
'tengu_unhandled_rejection',
|
||||
'tengu_voice_recording_started',
|
||||
'tengu_voice_toggled',
|
||||
'tengu_team_mem_sync_pull',
|
||||
'tengu_team_mem_sync_push',
|
||||
'tengu_team_mem_sync_started',
|
||||
'tengu_team_mem_entries_capped',
|
||||
])
|
||||
|
||||
const TAG_FIELDS = [
|
||||
'arch',
|
||||
'clientType',
|
||||
'errorType',
|
||||
'http_status_range',
|
||||
'http_status',
|
||||
'kairosActive',
|
||||
'model',
|
||||
'platform',
|
||||
'provider',
|
||||
'skillMode',
|
||||
'subscriptionType',
|
||||
'toolName',
|
||||
'userBucket',
|
||||
'userType',
|
||||
'version',
|
||||
'versionBase',
|
||||
]
|
||||
|
||||
function camelToSnakeCase(str: string): string {
|
||||
return str.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`)
|
||||
}
|
||||
|
||||
type DatadogLog = {
|
||||
ddsource: string
|
||||
ddtags: string
|
||||
message: string
|
||||
service: string
|
||||
hostname: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
let logBatch: DatadogLog[] = []
|
||||
let flushTimer: NodeJS.Timeout | null = null
|
||||
let datadogInitialized: boolean | null = null
|
||||
|
||||
async function flushLogs(): Promise<void> {
|
||||
if (logBatch.length === 0) return
|
||||
|
||||
const logsToSend = logBatch
|
||||
logBatch = []
|
||||
|
||||
try {
|
||||
await axios.post(DATADOG_LOGS_ENDPOINT, logsToSend, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'DD-API-KEY': DATADOG_CLIENT_TOKEN,
|
||||
},
|
||||
timeout: NETWORK_TIMEOUT_MS,
|
||||
})
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
|
||||
function scheduleFlush(): void {
|
||||
if (flushTimer) return
|
||||
|
||||
flushTimer = setTimeout(() => {
|
||||
flushTimer = null
|
||||
void flushLogs()
|
||||
}, getFlushIntervalMs()).unref()
|
||||
}
|
||||
|
||||
export const initializeDatadog = memoize(async (): Promise<boolean> => {
|
||||
if (isAnalyticsDisabled()) {
|
||||
datadogInitialized = false
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
datadogInitialized = true
|
||||
return true
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
datadogInitialized = false
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* Flush remaining Datadog logs and shut down.
|
||||
* Called from gracefulShutdown() before process.exit() since
|
||||
* forceExit() prevents the beforeExit handler from firing.
|
||||
*/
|
||||
export async function shutdownDatadog(): Promise<void> {
|
||||
if (flushTimer) {
|
||||
clearTimeout(flushTimer)
|
||||
flushTimer = null
|
||||
}
|
||||
await flushLogs()
|
||||
}
|
||||
|
||||
// NOTE: use via src/services/analytics/index.ts > logEvent
|
||||
export async function trackDatadogEvent(
|
||||
eventName: string,
|
||||
properties: { [key: string]: boolean | number | undefined },
|
||||
): Promise<void> {
|
||||
if (process.env.NODE_ENV !== 'production') {
|
||||
return
|
||||
}
|
||||
|
||||
// Don't send events for 3P providers (Bedrock, Vertex, Foundry)
|
||||
if (getAPIProvider() !== 'firstParty') {
|
||||
return
|
||||
}
|
||||
|
||||
// Fast path: use cached result if available to avoid await overhead
|
||||
let initialized = datadogInitialized
|
||||
if (initialized === null) {
|
||||
initialized = await initializeDatadog()
|
||||
}
|
||||
if (!initialized || !DATADOG_ALLOWED_EVENTS.has(eventName)) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const metadata = await getEventMetadata({
|
||||
model: properties.model,
|
||||
betas: properties.betas,
|
||||
})
|
||||
// Destructure to avoid duplicate envContext (once nested, once flattened)
|
||||
const { envContext, ...restMetadata } = metadata
|
||||
const allData: Record<string, unknown> = {
|
||||
...restMetadata,
|
||||
...envContext,
|
||||
...properties,
|
||||
userBucket: getUserBucket(),
|
||||
}
|
||||
|
||||
// Normalize MCP tool names to "mcp" for cardinality reduction
|
||||
if (
|
||||
typeof allData.toolName === 'string' &&
|
||||
allData.toolName.startsWith('mcp__')
|
||||
) {
|
||||
allData.toolName = 'mcp'
|
||||
}
|
||||
|
||||
// Normalize model names for cardinality reduction (external users only)
|
||||
if (process.env.USER_TYPE !== 'ant' && typeof allData.model === 'string') {
|
||||
const shortName = getCanonicalName(allData.model.replace(/\[1m]$/i, ''))
|
||||
allData.model = shortName in MODEL_COSTS ? shortName : 'other'
|
||||
}
|
||||
|
||||
// Truncate dev version to base + date (remove timestamp and sha for cardinality reduction)
|
||||
// e.g. "2.0.53-dev.20251124.t173302.sha526cc6a" -> "2.0.53-dev.20251124"
|
||||
if (typeof allData.version === 'string') {
|
||||
allData.version = allData.version.replace(
|
||||
/^(\d+\.\d+\.\d+-dev\.\d{8})\.t\d+\.sha[a-f0-9]+$/,
|
||||
'$1',
|
||||
)
|
||||
}
|
||||
|
||||
// Transform status to http_status and http_status_range to avoid Datadog reserved field
|
||||
if (allData.status !== undefined && allData.status !== null) {
|
||||
const statusCode = String(allData.status)
|
||||
allData.http_status = statusCode
|
||||
|
||||
// Determine status range (1xx, 2xx, 3xx, 4xx, 5xx)
|
||||
const firstDigit = statusCode.charAt(0)
|
||||
if (firstDigit >= '1' && firstDigit <= '5') {
|
||||
allData.http_status_range = `${firstDigit}xx`
|
||||
}
|
||||
|
||||
// Remove original status field to avoid conflict with Datadog's reserved field
|
||||
delete allData.status
|
||||
}
|
||||
|
||||
// Build ddtags with high-cardinality fields for filtering.
|
||||
// event:<name> is prepended so the event name is searchable via the
|
||||
// log search API — the `message` field (where eventName also lives)
|
||||
// is a DD reserved field and is NOT queryable from dashboard widget
|
||||
// queries or the aggregation API. See scripts/release/MONITORING.md.
|
||||
const allDataRecord = allData
|
||||
const tags = [
|
||||
`event:${eventName}`,
|
||||
...TAG_FIELDS.filter(
|
||||
field =>
|
||||
allDataRecord[field] !== undefined && allDataRecord[field] !== null,
|
||||
).map(field => `${camelToSnakeCase(field)}:${allDataRecord[field]}`),
|
||||
]
|
||||
|
||||
const log: DatadogLog = {
|
||||
ddsource: 'nodejs',
|
||||
ddtags: tags.join(','),
|
||||
message: eventName,
|
||||
service: 'claude-code',
|
||||
hostname: 'claude-code',
|
||||
env: process.env.USER_TYPE,
|
||||
}
|
||||
|
||||
// Add all fields as searchable attributes (not duplicated in tags)
|
||||
for (const [key, value] of Object.entries(allData)) {
|
||||
if (value !== undefined && value !== null) {
|
||||
log[camelToSnakeCase(key)] = value
|
||||
}
|
||||
}
|
||||
|
||||
logBatch.push(log)
|
||||
|
||||
// Flush immediately if batch is full, otherwise schedule
|
||||
if (logBatch.length >= MAX_BATCH_SIZE) {
|
||||
if (flushTimer) {
|
||||
clearTimeout(flushTimer)
|
||||
flushTimer = null
|
||||
}
|
||||
void flushLogs()
|
||||
} else {
|
||||
scheduleFlush()
|
||||
}
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
|
||||
const NUM_USER_BUCKETS = 30
|
||||
|
||||
/**
|
||||
* Gets a 'bucket' that the user ID falls into.
|
||||
*
|
||||
* For alerting purposes, we want to alert on the number of users impacted
|
||||
* by an issue, rather than the number of events- often a small number of users
|
||||
* can generate a large number of events (e.g. due to retries). To approximate
|
||||
* this without ruining cardinality by counting user IDs directly, we hash the user ID
|
||||
* and assign it to one of a fixed number of buckets.
|
||||
*
|
||||
* This allows us to estimate the number of unique users by counting unique buckets,
|
||||
* while preserving user privacy and reducing cardinality.
|
||||
*/
|
||||
const getUserBucket = memoize((): number => {
|
||||
const userId = getOrCreateUserID()
|
||||
const hash = createHash('sha256').update(userId).digest('hex')
|
||||
return parseInt(hash.slice(0, 8), 16) % NUM_USER_BUCKETS
|
||||
})
|
||||
|
||||
function getFlushIntervalMs(): number {
|
||||
// Allow tests to override to not block on the default flush interval.
|
||||
return (
|
||||
parseInt(process.env.CLAUDE_CODE_DATADOG_FLUSH_INTERVAL_MS || '', 10) ||
|
||||
DEFAULT_FLUSH_INTERVAL_MS
|
||||
)
|
||||
}
|
||||
449
src/services/analytics/firstPartyEventLogger.ts
Normal file
449
src/services/analytics/firstPartyEventLogger.ts
Normal file
@@ -0,0 +1,449 @@
|
||||
import type { AnyValueMap, Logger, logs } from '@opentelemetry/api-logs'
|
||||
import { resourceFromAttributes } from '@opentelemetry/resources'
|
||||
import {
|
||||
BatchLogRecordProcessor,
|
||||
LoggerProvider,
|
||||
} from '@opentelemetry/sdk-logs'
|
||||
import {
|
||||
ATTR_SERVICE_NAME,
|
||||
ATTR_SERVICE_VERSION,
|
||||
} from '@opentelemetry/semantic-conventions'
|
||||
import { randomUUID } from 'crypto'
|
||||
import { isEqual } from 'lodash-es'
|
||||
import { getOrCreateUserID } from '../../utils/config.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getPlatform, getWslVersion } from '../../utils/platform.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import { profileCheckpoint } from '../../utils/startupProfiler.js'
|
||||
import { getCoreUserData } from '../../utils/user.js'
|
||||
import { isAnalyticsDisabled } from './config.js'
|
||||
import { FirstPartyEventLoggingExporter } from './firstPartyEventLoggingExporter.js'
|
||||
import type { GrowthBookUserAttributes } from './growthbook.js'
|
||||
import { getDynamicConfig_CACHED_MAY_BE_STALE } from './growthbook.js'
|
||||
import { getEventMetadata } from './metadata.js'
|
||||
import { isSinkKilled } from './sinkKillswitch.js'
|
||||
|
||||
/**
|
||||
* Configuration for sampling individual event types.
|
||||
* Each event name maps to an object containing sample_rate (0-1).
|
||||
* Events not in the config are logged at 100% rate.
|
||||
*/
|
||||
export type EventSamplingConfig = {
|
||||
[eventName: string]: {
|
||||
sample_rate: number
|
||||
}
|
||||
}
|
||||
|
||||
const EVENT_SAMPLING_CONFIG_NAME = 'tengu_event_sampling_config'
|
||||
/**
|
||||
* Get the event sampling configuration from GrowthBook.
|
||||
* Uses cached value if available, updates cache in background.
|
||||
*/
|
||||
export function getEventSamplingConfig(): EventSamplingConfig {
|
||||
return getDynamicConfig_CACHED_MAY_BE_STALE<EventSamplingConfig>(
|
||||
EVENT_SAMPLING_CONFIG_NAME,
|
||||
{},
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine if an event should be sampled based on its sample rate.
|
||||
* Returns the sample rate if sampled, null if not sampled.
|
||||
*
|
||||
* @param eventName - Name of the event to check
|
||||
* @returns The sample_rate if event should be logged, null if it should be dropped
|
||||
*/
|
||||
export function shouldSampleEvent(eventName: string): number | null {
|
||||
const config = getEventSamplingConfig()
|
||||
const eventConfig = config[eventName]
|
||||
|
||||
// If no config for this event, log at 100% rate (no sampling)
|
||||
if (!eventConfig) {
|
||||
return null
|
||||
}
|
||||
|
||||
const sampleRate = eventConfig.sample_rate
|
||||
|
||||
// Validate sample rate is in valid range
|
||||
if (typeof sampleRate !== 'number' || sampleRate < 0 || sampleRate > 1) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Sample rate of 1 means log everything (no need to add metadata)
|
||||
if (sampleRate >= 1) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Sample rate of 0 means drop everything
|
||||
if (sampleRate <= 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Randomly decide whether to sample this event
|
||||
return Math.random() < sampleRate ? sampleRate : 0
|
||||
}
|
||||
|
||||
const BATCH_CONFIG_NAME = 'tengu_1p_event_batch_config'
|
||||
type BatchConfig = {
|
||||
scheduledDelayMillis?: number
|
||||
maxExportBatchSize?: number
|
||||
maxQueueSize?: number
|
||||
skipAuth?: boolean
|
||||
maxAttempts?: number
|
||||
path?: string
|
||||
baseUrl?: string
|
||||
}
|
||||
function getBatchConfig(): BatchConfig {
|
||||
return getDynamicConfig_CACHED_MAY_BE_STALE<BatchConfig>(
|
||||
BATCH_CONFIG_NAME,
|
||||
{},
|
||||
)
|
||||
}
|
||||
|
||||
// Module-local state for event logging (not exposed globally)
|
||||
let firstPartyEventLogger: ReturnType<typeof logs.getLogger> | null = null
|
||||
let firstPartyEventLoggerProvider: LoggerProvider | null = null
|
||||
// Last batch config used to construct the provider — used by
|
||||
// reinitialize1PEventLoggingIfConfigChanged to decide whether a rebuild is
|
||||
// needed when GrowthBook refreshes.
|
||||
let lastBatchConfig: BatchConfig | null = null
|
||||
/**
|
||||
* Flush and shutdown the 1P event logger.
|
||||
* This should be called as the final step before process exit to ensure
|
||||
* all events (including late ones from API responses) are exported.
|
||||
*/
|
||||
export async function shutdown1PEventLogging(): Promise<void> {
|
||||
if (!firstPartyEventLoggerProvider) {
|
||||
return
|
||||
}
|
||||
try {
|
||||
await firstPartyEventLoggerProvider.shutdown()
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging('1P event logging: final shutdown complete')
|
||||
}
|
||||
} catch {
|
||||
// Ignore shutdown errors
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if 1P event logging is enabled.
|
||||
* Respects the same opt-outs as other analytics sinks:
|
||||
* - Test environment
|
||||
* - Third-party cloud providers (Bedrock/Vertex)
|
||||
* - Global telemetry opt-outs
|
||||
* - Non-essential traffic disabled
|
||||
*
|
||||
* Note: Unlike BigQuery metrics, event logging does NOT check organization-level
|
||||
* metrics opt-out via API. It follows the same pattern as Statsig event logging.
|
||||
*/
|
||||
export function is1PEventLoggingEnabled(): boolean {
|
||||
// Respect standard analytics opt-outs
|
||||
return !isAnalyticsDisabled()
|
||||
}
|
||||
|
||||
/**
|
||||
* Log a 1st-party event for internal analytics (async version).
|
||||
* Events are batched and exported to /api/event_logging/batch
|
||||
*
|
||||
* This enriches the event with core metadata (model, session, env context, etc.)
|
||||
* at log time, similar to logEventToStatsig.
|
||||
*
|
||||
* @param eventName - Name of the event (e.g., 'tengu_api_query')
|
||||
* @param metadata - Additional metadata for the event (intentionally no strings, to avoid accidentally logging code/filepaths)
|
||||
*/
|
||||
async function logEventTo1PAsync(
|
||||
firstPartyEventLogger: Logger,
|
||||
eventName: string,
|
||||
metadata: Record<string, number | boolean | undefined> = {},
|
||||
): Promise<void> {
|
||||
try {
|
||||
// Enrich with core metadata at log time (similar to Statsig pattern)
|
||||
const coreMetadata = await getEventMetadata({
|
||||
model: metadata.model,
|
||||
betas: metadata.betas,
|
||||
})
|
||||
|
||||
// Build attributes - OTel supports nested objects natively via AnyValueMap
|
||||
// Cast through unknown since our nested objects are structurally compatible
|
||||
// with AnyValue but TS doesn't recognize it due to missing index signatures
|
||||
const attributes = {
|
||||
event_name: eventName,
|
||||
event_id: randomUUID(),
|
||||
// Pass objects directly - no JSON serialization needed
|
||||
core_metadata: coreMetadata,
|
||||
user_metadata: getCoreUserData(true),
|
||||
event_metadata: metadata,
|
||||
} as unknown as AnyValueMap
|
||||
|
||||
// Add user_id if available
|
||||
const userId = getOrCreateUserID()
|
||||
if (userId) {
|
||||
attributes.user_id = userId
|
||||
}
|
||||
|
||||
// Debug logging when debug mode is enabled
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`[ANT-ONLY] 1P event: ${eventName} ${jsonStringify(metadata, null, 0)}`,
|
||||
)
|
||||
}
|
||||
|
||||
// Emit log record
|
||||
firstPartyEventLogger.emit({
|
||||
body: eventName,
|
||||
attributes,
|
||||
})
|
||||
} catch (e) {
|
||||
if (process.env.NODE_ENV === 'development') {
|
||||
throw e
|
||||
}
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logError(e as Error)
|
||||
}
|
||||
// swallow
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Log a 1st-party event for internal analytics.
|
||||
* Events are batched and exported to /api/event_logging/batch
|
||||
*
|
||||
* @param eventName - Name of the event (e.g., 'tengu_api_query')
|
||||
* @param metadata - Additional metadata for the event (intentionally no strings, to avoid accidentally logging code/filepaths)
|
||||
*/
|
||||
export function logEventTo1P(
|
||||
eventName: string,
|
||||
metadata: Record<string, number | boolean | undefined> = {},
|
||||
): void {
|
||||
if (!is1PEventLoggingEnabled()) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!firstPartyEventLogger || isSinkKilled('firstParty')) {
|
||||
return
|
||||
}
|
||||
|
||||
// Fire and forget - don't block on metadata enrichment
|
||||
void logEventTo1PAsync(firstPartyEventLogger, eventName, metadata)
|
||||
}
|
||||
|
||||
/**
|
||||
* GrowthBook experiment event data for logging
|
||||
*/
|
||||
export type GrowthBookExperimentData = {
|
||||
experimentId: string
|
||||
variationId: number
|
||||
userAttributes?: GrowthBookUserAttributes
|
||||
experimentMetadata?: Record<string, unknown>
|
||||
}
|
||||
|
||||
// api.anthropic.com only serves the "production" GrowthBook environment
|
||||
// (see starling/starling/cli/cli.py DEFAULT_ENVIRONMENTS). Staging and
|
||||
// development environments are not exported to the prod API.
|
||||
function getEnvironmentForGrowthBook(): string {
|
||||
return 'production'
|
||||
}
|
||||
|
||||
/**
|
||||
* Log a GrowthBook experiment assignment event to 1P.
|
||||
* Events are batched and exported to /api/event_logging/batch
|
||||
*
|
||||
* @param data - GrowthBook experiment assignment data
|
||||
*/
|
||||
export function logGrowthBookExperimentTo1P(
|
||||
data: GrowthBookExperimentData,
|
||||
): void {
|
||||
if (!is1PEventLoggingEnabled()) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!firstPartyEventLogger || isSinkKilled('firstParty')) {
|
||||
return
|
||||
}
|
||||
|
||||
const userId = getOrCreateUserID()
|
||||
const { accountUuid, organizationUuid } = getCoreUserData(true)
|
||||
|
||||
// Build attributes for GrowthbookExperimentEvent
|
||||
const attributes = {
|
||||
event_type: 'GrowthbookExperimentEvent',
|
||||
event_id: randomUUID(),
|
||||
experiment_id: data.experimentId,
|
||||
variation_id: data.variationId,
|
||||
...(userId && { device_id: userId }),
|
||||
...(accountUuid && { account_uuid: accountUuid }),
|
||||
...(organizationUuid && { organization_uuid: organizationUuid }),
|
||||
...(data.userAttributes && {
|
||||
session_id: data.userAttributes.sessionId,
|
||||
user_attributes: jsonStringify(data.userAttributes),
|
||||
}),
|
||||
...(data.experimentMetadata && {
|
||||
experiment_metadata: jsonStringify(data.experimentMetadata),
|
||||
}),
|
||||
environment: getEnvironmentForGrowthBook(),
|
||||
}
|
||||
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`[ANT-ONLY] 1P GrowthBook experiment: ${data.experimentId} variation=${data.variationId}`,
|
||||
)
|
||||
}
|
||||
|
||||
firstPartyEventLogger.emit({
|
||||
body: 'growthbook_experiment',
|
||||
attributes,
|
||||
})
|
||||
}
|
||||
|
||||
const DEFAULT_LOGS_EXPORT_INTERVAL_MS = 10000
|
||||
const DEFAULT_MAX_EXPORT_BATCH_SIZE = 200
|
||||
const DEFAULT_MAX_QUEUE_SIZE = 8192
|
||||
|
||||
/**
|
||||
* Initialize 1P event logging infrastructure.
|
||||
* This creates a separate LoggerProvider for internal event logging,
|
||||
* independent of customer OTLP telemetry.
|
||||
*
|
||||
* This uses its own minimal resource configuration with just the attributes
|
||||
* we need for internal analytics (service name, version, platform info).
|
||||
*/
|
||||
export function initialize1PEventLogging(): void {
|
||||
profileCheckpoint('1p_event_logging_start')
|
||||
const enabled = is1PEventLoggingEnabled()
|
||||
|
||||
if (!enabled) {
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging('1P event logging not enabled')
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch batch processor configuration from GrowthBook dynamic config
|
||||
// Uses cached value if available, refreshes in background
|
||||
const batchConfig = getBatchConfig()
|
||||
lastBatchConfig = batchConfig
|
||||
profileCheckpoint('1p_event_after_growthbook_config')
|
||||
|
||||
const scheduledDelayMillis =
|
||||
batchConfig.scheduledDelayMillis ||
|
||||
parseInt(
|
||||
process.env.OTEL_LOGS_EXPORT_INTERVAL ||
|
||||
DEFAULT_LOGS_EXPORT_INTERVAL_MS.toString(),
|
||||
)
|
||||
|
||||
const maxExportBatchSize =
|
||||
batchConfig.maxExportBatchSize || DEFAULT_MAX_EXPORT_BATCH_SIZE
|
||||
|
||||
const maxQueueSize = batchConfig.maxQueueSize || DEFAULT_MAX_QUEUE_SIZE
|
||||
|
||||
// Build our own resource for 1P event logging with minimal attributes
|
||||
const platform = getPlatform()
|
||||
const attributes: Record<string, string> = {
|
||||
[ATTR_SERVICE_NAME]: 'claude-code',
|
||||
[ATTR_SERVICE_VERSION]: MACRO.VERSION,
|
||||
}
|
||||
|
||||
// Add WSL-specific attributes if running on WSL
|
||||
if (platform === 'wsl') {
|
||||
const wslVersion = getWslVersion()
|
||||
if (wslVersion) {
|
||||
attributes['wsl.version'] = wslVersion
|
||||
}
|
||||
}
|
||||
|
||||
const resource = resourceFromAttributes(attributes)
|
||||
|
||||
// Create a new LoggerProvider with the EventLoggingExporter
|
||||
// NOTE: This is kept separate from customer telemetry logs to ensure
|
||||
// internal events don't leak to customer endpoints and vice versa.
|
||||
// We don't register this globally - it's only used for internal event logging.
|
||||
const eventLoggingExporter = new FirstPartyEventLoggingExporter({
|
||||
maxBatchSize: maxExportBatchSize,
|
||||
skipAuth: batchConfig.skipAuth,
|
||||
maxAttempts: batchConfig.maxAttempts,
|
||||
path: batchConfig.path,
|
||||
baseUrl: batchConfig.baseUrl,
|
||||
isKilled: () => isSinkKilled('firstParty'),
|
||||
})
|
||||
firstPartyEventLoggerProvider = new LoggerProvider({
|
||||
resource,
|
||||
processors: [
|
||||
new BatchLogRecordProcessor(eventLoggingExporter, {
|
||||
scheduledDelayMillis,
|
||||
maxExportBatchSize,
|
||||
maxQueueSize,
|
||||
}),
|
||||
],
|
||||
})
|
||||
|
||||
// Initialize event logger from our internal provider (NOT from global API)
|
||||
// IMPORTANT: We must get the logger from our local provider, not logs.getLogger()
|
||||
// because logs.getLogger() returns a logger from the global provider, which is
|
||||
// separate and used for customer telemetry.
|
||||
firstPartyEventLogger = firstPartyEventLoggerProvider.getLogger(
|
||||
'com.anthropic.claude_code.events',
|
||||
MACRO.VERSION,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Rebuild the 1P event logging pipeline if the batch config changed.
|
||||
* Register this with onGrowthBookRefresh so long-running sessions pick up
|
||||
* changes to batch size, delay, endpoint, etc.
|
||||
*
|
||||
* Event-loss safety:
|
||||
* 1. Null the logger first — concurrent logEventTo1P() calls hit the
|
||||
* !firstPartyEventLogger guard and bail during the swap window. This drops
|
||||
* a handful of events but prevents emitting to a draining provider.
|
||||
* 2. forceFlush() drains the old BatchLogRecordProcessor buffer to the
|
||||
* exporter. Export failures go to disk at getCurrentBatchFilePath() which
|
||||
* is keyed by module-level BATCH_UUID + sessionId — unchanged across
|
||||
* reinit — so the NEW exporter's disk-backed retry picks them up.
|
||||
* 3. Swap to new provider/logger; old provider shutdown runs in background
|
||||
* (buffer already drained, just cleanup).
|
||||
*/
|
||||
export async function reinitialize1PEventLoggingIfConfigChanged(): Promise<void> {
|
||||
if (!is1PEventLoggingEnabled() || !firstPartyEventLoggerProvider) {
|
||||
return
|
||||
}
|
||||
|
||||
const newConfig = getBatchConfig()
|
||||
|
||||
if (isEqual(newConfig, lastBatchConfig)) {
|
||||
return
|
||||
}
|
||||
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: ${BATCH_CONFIG_NAME} changed, reinitializing`,
|
||||
)
|
||||
}
|
||||
|
||||
const oldProvider = firstPartyEventLoggerProvider
|
||||
const oldLogger = firstPartyEventLogger
|
||||
firstPartyEventLogger = null
|
||||
|
||||
try {
|
||||
await oldProvider.forceFlush()
|
||||
} catch {
|
||||
// Export failures are already on disk; new exporter will retry them.
|
||||
}
|
||||
|
||||
firstPartyEventLoggerProvider = null
|
||||
try {
|
||||
initialize1PEventLogging()
|
||||
} catch (e) {
|
||||
// Restore so the next GrowthBook refresh can retry. oldProvider was
|
||||
// only forceFlush()'d, not shut down — it's still functional. Without
|
||||
// this, both stay null and the !firstPartyEventLoggerProvider gate at
|
||||
// the top makes recovery impossible.
|
||||
firstPartyEventLoggerProvider = oldProvider
|
||||
firstPartyEventLogger = oldLogger
|
||||
logError(e)
|
||||
return
|
||||
}
|
||||
|
||||
void oldProvider.shutdown().catch(() => {})
|
||||
}
|
||||
806
src/services/analytics/firstPartyEventLoggingExporter.ts
Normal file
806
src/services/analytics/firstPartyEventLoggingExporter.ts
Normal file
@@ -0,0 +1,806 @@
|
||||
import type { HrTime } from '@opentelemetry/api'
|
||||
import { type ExportResult, ExportResultCode } from '@opentelemetry/core'
|
||||
import type {
|
||||
LogRecordExporter,
|
||||
ReadableLogRecord,
|
||||
} from '@opentelemetry/sdk-logs'
|
||||
import axios from 'axios'
|
||||
import { randomUUID } from 'crypto'
|
||||
import { appendFile, mkdir, readdir, unlink, writeFile } from 'fs/promises'
|
||||
import * as path from 'path'
|
||||
import type { CoreUserData } from 'src/utils/user.js'
|
||||
import {
|
||||
getIsNonInteractiveSession,
|
||||
getSessionId,
|
||||
} from '../../bootstrap/state.js'
|
||||
import { ClaudeCodeInternalEvent } from '../../types/generated/events_mono/claude_code/v1/claude_code_internal_event.js'
|
||||
import { GrowthbookExperimentEvent } from '../../types/generated/events_mono/growthbook/v1/growthbook_experiment_event.js'
|
||||
import {
|
||||
getClaudeAIOAuthTokens,
|
||||
hasProfileScope,
|
||||
isClaudeAISubscriber,
|
||||
} from '../../utils/auth.js'
|
||||
import { checkHasTrustDialogAccepted } from '../../utils/config.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { getClaudeConfigHomeDir } from '../../utils/envUtils.js'
|
||||
import { errorMessage, isFsInaccessible, toError } from '../../utils/errors.js'
|
||||
import { getAuthHeaders } from '../../utils/http.js'
|
||||
import { readJSONLFile } from '../../utils/json.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { sleep } from '../../utils/sleep.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import { getClaudeCodeUserAgent } from '../../utils/userAgent.js'
|
||||
import { isOAuthTokenExpired } from '../oauth/client.js'
|
||||
import { stripProtoFields } from './index.js'
|
||||
import { type EventMetadata, to1PEventFormat } from './metadata.js'
|
||||
|
||||
// Unique ID for this process run - used to isolate failed event files between runs
|
||||
const BATCH_UUID = randomUUID()
|
||||
|
||||
// File prefix for failed event storage
|
||||
const FILE_PREFIX = '1p_failed_events.'
|
||||
|
||||
// Storage directory for failed events - evaluated at runtime to respect CLAUDE_CONFIG_DIR in tests
|
||||
function getStorageDir(): string {
|
||||
return path.join(getClaudeConfigHomeDir(), 'telemetry')
|
||||
}
|
||||
|
||||
// API envelope - event_data is the JSON output from proto toJSON()
|
||||
type FirstPartyEventLoggingEvent = {
|
||||
event_type: 'ClaudeCodeInternalEvent' | 'GrowthbookExperimentEvent'
|
||||
event_data: unknown
|
||||
}
|
||||
|
||||
type FirstPartyEventLoggingPayload = {
|
||||
events: FirstPartyEventLoggingEvent[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Exporter for 1st-party event logging to /api/event_logging/batch.
|
||||
*
|
||||
* Export cycles are controlled by OpenTelemetry's BatchLogRecordProcessor, which
|
||||
* triggers export() when either:
|
||||
* - Time interval elapses (default: 5 seconds via scheduledDelayMillis)
|
||||
* - Batch size is reached (default: 200 events via maxExportBatchSize)
|
||||
*
|
||||
* This exporter adds resilience on top:
|
||||
* - Append-only log for failed events (concurrency-safe)
|
||||
* - Quadratic backoff retry for failed events, dropped after maxAttempts
|
||||
* - Immediate retry of queued events when any export succeeds (endpoint is healthy)
|
||||
* - Chunking large event sets into smaller batches
|
||||
* - Auth fallback: retries without auth on 401 errors
|
||||
*/
|
||||
export class FirstPartyEventLoggingExporter implements LogRecordExporter {
|
||||
private readonly endpoint: string
|
||||
private readonly timeout: number
|
||||
private readonly maxBatchSize: number
|
||||
private readonly skipAuth: boolean
|
||||
private readonly batchDelayMs: number
|
||||
private readonly baseBackoffDelayMs: number
|
||||
private readonly maxBackoffDelayMs: number
|
||||
private readonly maxAttempts: number
|
||||
private readonly isKilled: () => boolean
|
||||
private pendingExports: Promise<void>[] = []
|
||||
private isShutdown = false
|
||||
private readonly schedule: (
|
||||
fn: () => Promise<void>,
|
||||
delayMs: number,
|
||||
) => () => void
|
||||
private cancelBackoff: (() => void) | null = null
|
||||
private attempts = 0
|
||||
private isRetrying = false
|
||||
private lastExportErrorContext: string | undefined
|
||||
|
||||
constructor(
|
||||
options: {
|
||||
timeout?: number
|
||||
maxBatchSize?: number
|
||||
skipAuth?: boolean
|
||||
batchDelayMs?: number
|
||||
baseBackoffDelayMs?: number
|
||||
maxBackoffDelayMs?: number
|
||||
maxAttempts?: number
|
||||
path?: string
|
||||
baseUrl?: string
|
||||
// Injected killswitch probe. Checked per-POST so that disabling the
|
||||
// firstParty sink also stops backoff retries (not just new emits).
|
||||
// Passed in rather than imported to avoid a cycle with firstPartyEventLogger.ts.
|
||||
isKilled?: () => boolean
|
||||
schedule?: (fn: () => Promise<void>, delayMs: number) => () => void
|
||||
} = {},
|
||||
) {
|
||||
// Default: prod, except when ANTHROPIC_BASE_URL is explicitly staging.
|
||||
// Overridable via tengu_1p_event_batch_config.baseUrl.
|
||||
const baseUrl =
|
||||
options.baseUrl ||
|
||||
(process.env.ANTHROPIC_BASE_URL === 'https://api-staging.anthropic.com'
|
||||
? 'https://api-staging.anthropic.com'
|
||||
: 'https://api.anthropic.com')
|
||||
|
||||
this.endpoint = `${baseUrl}${options.path || '/api/event_logging/batch'}`
|
||||
|
||||
this.timeout = options.timeout || 10000
|
||||
this.maxBatchSize = options.maxBatchSize || 200
|
||||
this.skipAuth = options.skipAuth ?? false
|
||||
this.batchDelayMs = options.batchDelayMs || 100
|
||||
this.baseBackoffDelayMs = options.baseBackoffDelayMs || 500
|
||||
this.maxBackoffDelayMs = options.maxBackoffDelayMs || 30000
|
||||
this.maxAttempts = options.maxAttempts ?? 8
|
||||
this.isKilled = options.isKilled ?? (() => false)
|
||||
this.schedule =
|
||||
options.schedule ??
|
||||
((fn, ms) => {
|
||||
const t = setTimeout(fn, ms)
|
||||
return () => clearTimeout(t)
|
||||
})
|
||||
|
||||
// Retry any failed events from previous runs of this session (in background)
|
||||
void this.retryPreviousBatches()
|
||||
}
|
||||
|
||||
// Expose for testing
|
||||
async getQueuedEventCount(): Promise<number> {
|
||||
return (await this.loadEventsFromCurrentBatch()).length
|
||||
}
|
||||
|
||||
// --- Storage helpers ---
|
||||
|
||||
private getCurrentBatchFilePath(): string {
|
||||
return path.join(
|
||||
getStorageDir(),
|
||||
`${FILE_PREFIX}${getSessionId()}.${BATCH_UUID}.json`,
|
||||
)
|
||||
}
|
||||
|
||||
private async loadEventsFromFile(
|
||||
filePath: string,
|
||||
): Promise<FirstPartyEventLoggingEvent[]> {
|
||||
try {
|
||||
return await readJSONLFile<FirstPartyEventLoggingEvent>(filePath)
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
private async loadEventsFromCurrentBatch(): Promise<
|
||||
FirstPartyEventLoggingEvent[]
|
||||
> {
|
||||
return this.loadEventsFromFile(this.getCurrentBatchFilePath())
|
||||
}
|
||||
|
||||
private async saveEventsToFile(
|
||||
filePath: string,
|
||||
events: FirstPartyEventLoggingEvent[],
|
||||
): Promise<void> {
|
||||
try {
|
||||
if (events.length === 0) {
|
||||
try {
|
||||
await unlink(filePath)
|
||||
} catch {
|
||||
// File doesn't exist, nothing to delete
|
||||
}
|
||||
} else {
|
||||
// Ensure storage directory exists
|
||||
await mkdir(getStorageDir(), { recursive: true })
|
||||
// Write as JSON lines (one event per line)
|
||||
const content = events.map(e => jsonStringify(e)).join('\n') + '\n'
|
||||
await writeFile(filePath, content, 'utf8')
|
||||
}
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
|
||||
private async appendEventsToFile(
|
||||
filePath: string,
|
||||
events: FirstPartyEventLoggingEvent[],
|
||||
): Promise<void> {
|
||||
if (events.length === 0) return
|
||||
try {
|
||||
// Ensure storage directory exists
|
||||
await mkdir(getStorageDir(), { recursive: true })
|
||||
// Append as JSON lines (one event per line) - atomic on most filesystems
|
||||
const content = events.map(e => jsonStringify(e)).join('\n') + '\n'
|
||||
await appendFile(filePath, content, 'utf8')
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
|
||||
private async deleteFile(filePath: string): Promise<void> {
|
||||
try {
|
||||
await unlink(filePath)
|
||||
} catch {
|
||||
// File doesn't exist or can't be deleted, ignore
|
||||
}
|
||||
}
|
||||
|
||||
// --- Previous batch retry (startup) ---
|
||||
|
||||
private async retryPreviousBatches(): Promise<void> {
|
||||
try {
|
||||
const prefix = `${FILE_PREFIX}${getSessionId()}.`
|
||||
let files: string[]
|
||||
try {
|
||||
files = (await readdir(getStorageDir()))
|
||||
.filter((f: string) => f.startsWith(prefix) && f.endsWith('.json'))
|
||||
.filter((f: string) => !f.includes(BATCH_UUID)) // Exclude current batch
|
||||
} catch (e) {
|
||||
if (isFsInaccessible(e)) return
|
||||
throw e
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
const filePath = path.join(getStorageDir(), file)
|
||||
void this.retryFileInBackground(filePath)
|
||||
}
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
|
||||
private async retryFileInBackground(filePath: string): Promise<void> {
|
||||
if (this.attempts >= this.maxAttempts) {
|
||||
await this.deleteFile(filePath)
|
||||
return
|
||||
}
|
||||
|
||||
const events = await this.loadEventsFromFile(filePath)
|
||||
if (events.length === 0) {
|
||||
await this.deleteFile(filePath)
|
||||
return
|
||||
}
|
||||
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: retrying ${events.length} events from previous batch`,
|
||||
)
|
||||
}
|
||||
|
||||
const failedEvents = await this.sendEventsInBatches(events)
|
||||
if (failedEvents.length === 0) {
|
||||
await this.deleteFile(filePath)
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging('1P event logging: previous batch retry succeeded')
|
||||
}
|
||||
} else {
|
||||
// Save only the failed events back (not all original events)
|
||||
await this.saveEventsToFile(filePath, failedEvents)
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: previous batch retry failed, ${failedEvents.length} events remain`,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async export(
|
||||
logs: ReadableLogRecord[],
|
||||
resultCallback: (result: ExportResult) => void,
|
||||
): Promise<void> {
|
||||
if (this.isShutdown) {
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
'1P event logging export failed: Exporter has been shutdown',
|
||||
)
|
||||
}
|
||||
resultCallback({
|
||||
code: ExportResultCode.FAILED,
|
||||
error: new Error('Exporter has been shutdown'),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const exportPromise = this.doExport(logs, resultCallback)
|
||||
this.pendingExports.push(exportPromise)
|
||||
|
||||
// Clean up completed exports
|
||||
void exportPromise.finally(() => {
|
||||
const index = this.pendingExports.indexOf(exportPromise)
|
||||
if (index > -1) {
|
||||
void this.pendingExports.splice(index, 1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private async doExport(
|
||||
logs: ReadableLogRecord[],
|
||||
resultCallback: (result: ExportResult) => void,
|
||||
): Promise<void> {
|
||||
try {
|
||||
// Filter for event logs only (by scope name)
|
||||
const eventLogs = logs.filter(
|
||||
log =>
|
||||
log.instrumentationScope?.name === 'com.anthropic.claude_code.events',
|
||||
)
|
||||
|
||||
if (eventLogs.length === 0) {
|
||||
resultCallback({ code: ExportResultCode.SUCCESS })
|
||||
return
|
||||
}
|
||||
|
||||
// Transform new logs (failed events are retried independently via backoff)
|
||||
const events = this.transformLogsToEvents(eventLogs).events
|
||||
|
||||
if (events.length === 0) {
|
||||
resultCallback({ code: ExportResultCode.SUCCESS })
|
||||
return
|
||||
}
|
||||
|
||||
if (this.attempts >= this.maxAttempts) {
|
||||
resultCallback({
|
||||
code: ExportResultCode.FAILED,
|
||||
error: new Error(
|
||||
`Dropped ${events.length} events: max attempts (${this.maxAttempts}) reached`,
|
||||
),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Send events
|
||||
const failedEvents = await this.sendEventsInBatches(events)
|
||||
this.attempts++
|
||||
|
||||
if (failedEvents.length > 0) {
|
||||
await this.queueFailedEvents(failedEvents)
|
||||
this.scheduleBackoffRetry()
|
||||
const context = this.lastExportErrorContext
|
||||
? ` (${this.lastExportErrorContext})`
|
||||
: ''
|
||||
resultCallback({
|
||||
code: ExportResultCode.FAILED,
|
||||
error: new Error(
|
||||
`Failed to export ${failedEvents.length} events${context}`,
|
||||
),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Success - reset backoff and immediately retry any queued events
|
||||
this.resetBackoff()
|
||||
if ((await this.getQueuedEventCount()) > 0 && !this.isRetrying) {
|
||||
void this.retryFailedEvents()
|
||||
}
|
||||
resultCallback({ code: ExportResultCode.SUCCESS })
|
||||
} catch (error) {
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging export failed: ${errorMessage(error)}`,
|
||||
)
|
||||
}
|
||||
logError(error)
|
||||
resultCallback({
|
||||
code: ExportResultCode.FAILED,
|
||||
error: toError(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
private async sendEventsInBatches(
|
||||
events: FirstPartyEventLoggingEvent[],
|
||||
): Promise<FirstPartyEventLoggingEvent[]> {
|
||||
// Chunk events into batches
|
||||
const batches: FirstPartyEventLoggingEvent[][] = []
|
||||
for (let i = 0; i < events.length; i += this.maxBatchSize) {
|
||||
batches.push(events.slice(i, i + this.maxBatchSize))
|
||||
}
|
||||
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: exporting ${events.length} events in ${batches.length} batch(es)`,
|
||||
)
|
||||
}
|
||||
|
||||
// Send each batch with delay between them. On first failure, assume the
|
||||
// endpoint is down and short-circuit: queue the failed batch plus all
|
||||
// remaining unsent batches without POSTing them. The backoff retry will
|
||||
// probe again with a single batch next tick.
|
||||
const failedBatchEvents: FirstPartyEventLoggingEvent[] = []
|
||||
let lastErrorContext: string | undefined
|
||||
for (let i = 0; i < batches.length; i++) {
|
||||
const batch = batches[i]!
|
||||
try {
|
||||
await this.sendBatchWithRetry({ events: batch })
|
||||
} catch (error) {
|
||||
lastErrorContext = getAxiosErrorContext(error)
|
||||
for (let j = i; j < batches.length; j++) {
|
||||
failedBatchEvents.push(...batches[j]!)
|
||||
}
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
const skipped = batches.length - 1 - i
|
||||
logForDebugging(
|
||||
`1P event logging: batch ${i + 1}/${batches.length} failed (${lastErrorContext}); short-circuiting ${skipped} remaining batch(es)`,
|
||||
)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if (i < batches.length - 1 && this.batchDelayMs > 0) {
|
||||
await sleep(this.batchDelayMs)
|
||||
}
|
||||
}
|
||||
|
||||
if (failedBatchEvents.length > 0 && lastErrorContext) {
|
||||
this.lastExportErrorContext = lastErrorContext
|
||||
}
|
||||
|
||||
return failedBatchEvents
|
||||
}
|
||||
|
||||
private async queueFailedEvents(
|
||||
events: FirstPartyEventLoggingEvent[],
|
||||
): Promise<void> {
|
||||
const filePath = this.getCurrentBatchFilePath()
|
||||
|
||||
// Append-only: just add new events to file (atomic on most filesystems)
|
||||
await this.appendEventsToFile(filePath, events)
|
||||
|
||||
const context = this.lastExportErrorContext
|
||||
? ` (${this.lastExportErrorContext})`
|
||||
: ''
|
||||
const message = `1P event logging: ${events.length} events failed to export${context}`
|
||||
logError(new Error(message))
|
||||
}
|
||||
|
||||
private scheduleBackoffRetry(): void {
|
||||
// Don't schedule if already retrying or shutdown
|
||||
if (this.cancelBackoff || this.isRetrying || this.isShutdown) {
|
||||
return
|
||||
}
|
||||
|
||||
// Quadratic backoff (matching Statsig SDK): base * attempts²
|
||||
const delay = Math.min(
|
||||
this.baseBackoffDelayMs * this.attempts * this.attempts,
|
||||
this.maxBackoffDelayMs,
|
||||
)
|
||||
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: scheduling backoff retry in ${delay}ms (attempt ${this.attempts})`,
|
||||
)
|
||||
}
|
||||
|
||||
this.cancelBackoff = this.schedule(async () => {
|
||||
this.cancelBackoff = null
|
||||
await this.retryFailedEvents()
|
||||
}, delay)
|
||||
}
|
||||
|
||||
private async retryFailedEvents(): Promise<void> {
|
||||
const filePath = this.getCurrentBatchFilePath()
|
||||
|
||||
// Keep retrying while there are events and endpoint is healthy
|
||||
while (!this.isShutdown) {
|
||||
const events = await this.loadEventsFromFile(filePath)
|
||||
if (events.length === 0) break
|
||||
|
||||
if (this.attempts >= this.maxAttempts) {
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: max attempts (${this.maxAttempts}) reached, dropping ${events.length} events`,
|
||||
)
|
||||
}
|
||||
await this.deleteFile(filePath)
|
||||
this.resetBackoff()
|
||||
return
|
||||
}
|
||||
|
||||
this.isRetrying = true
|
||||
|
||||
// Clear file before retry (we have events in memory now)
|
||||
await this.deleteFile(filePath)
|
||||
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: retrying ${events.length} failed events (attempt ${this.attempts + 1})`,
|
||||
)
|
||||
}
|
||||
|
||||
const failedEvents = await this.sendEventsInBatches(events)
|
||||
this.attempts++
|
||||
|
||||
this.isRetrying = false
|
||||
|
||||
if (failedEvents.length > 0) {
|
||||
// Write failures back to disk
|
||||
await this.saveEventsToFile(filePath, failedEvents)
|
||||
this.scheduleBackoffRetry()
|
||||
return // Failed - wait for backoff
|
||||
}
|
||||
|
||||
// Success - reset backoff and continue loop to drain any newly queued events
|
||||
this.resetBackoff()
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging('1P event logging: backoff retry succeeded')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private resetBackoff(): void {
|
||||
this.attempts = 0
|
||||
if (this.cancelBackoff) {
|
||||
this.cancelBackoff()
|
||||
this.cancelBackoff = null
|
||||
}
|
||||
}
|
||||
|
||||
private async sendBatchWithRetry(
|
||||
payload: FirstPartyEventLoggingPayload,
|
||||
): Promise<void> {
|
||||
if (this.isKilled()) {
|
||||
// Throw so the caller short-circuits remaining batches and queues
|
||||
// everything to disk. Zero network traffic while killed; the backoff
|
||||
// timer keeps ticking and will resume POSTs as soon as the GrowthBook
|
||||
// cache picks up the cleared flag.
|
||||
throw new Error('firstParty sink killswitch active')
|
||||
}
|
||||
|
||||
const baseHeaders: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
'x-service-name': 'claude-code',
|
||||
}
|
||||
|
||||
// Skip auth if trust hasn't been established yet
|
||||
// This prevents executing apiKeyHelper commands before the trust dialog
|
||||
// Non-interactive sessions implicitly have workspace trust
|
||||
const hasTrust =
|
||||
checkHasTrustDialogAccepted() || getIsNonInteractiveSession()
|
||||
if (process.env.USER_TYPE === 'ant' && !hasTrust) {
|
||||
logForDebugging('1P event logging: Trust not accepted')
|
||||
}
|
||||
|
||||
// Skip auth when the OAuth token is expired or lacks user:profile
|
||||
// scope (service key sessions). Falls through to unauthenticated send.
|
||||
let shouldSkipAuth = this.skipAuth || !hasTrust
|
||||
if (!shouldSkipAuth && isClaudeAISubscriber()) {
|
||||
const tokens = getClaudeAIOAuthTokens()
|
||||
if (!hasProfileScope()) {
|
||||
shouldSkipAuth = true
|
||||
} else if (tokens && isOAuthTokenExpired(tokens.expiresAt)) {
|
||||
shouldSkipAuth = true
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
'1P event logging: OAuth token expired, skipping auth to avoid 401',
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try with auth headers first (unless trust not established or token is known to be expired)
|
||||
const authResult = shouldSkipAuth
|
||||
? { headers: {}, error: 'trust not established or Oauth token expired' }
|
||||
: getAuthHeaders()
|
||||
const useAuth = !authResult.error
|
||||
|
||||
if (!useAuth && process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: auth not available, sending without auth`,
|
||||
)
|
||||
}
|
||||
|
||||
const headers = useAuth
|
||||
? { ...baseHeaders, ...authResult.headers }
|
||||
: baseHeaders
|
||||
|
||||
try {
|
||||
const response = await axios.post(this.endpoint, payload, {
|
||||
timeout: this.timeout,
|
||||
headers,
|
||||
})
|
||||
this.logSuccess(payload.events.length, useAuth, response.data)
|
||||
return
|
||||
} catch (error) {
|
||||
// Handle 401 by retrying without auth
|
||||
if (
|
||||
useAuth &&
|
||||
axios.isAxiosError(error) &&
|
||||
error.response?.status === 401
|
||||
) {
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
'1P event logging: 401 auth error, retrying without auth',
|
||||
)
|
||||
}
|
||||
const response = await axios.post(this.endpoint, payload, {
|
||||
timeout: this.timeout,
|
||||
headers: baseHeaders,
|
||||
})
|
||||
this.logSuccess(payload.events.length, false, response.data)
|
||||
return
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private logSuccess(
|
||||
eventCount: number,
|
||||
withAuth: boolean,
|
||||
responseData: unknown,
|
||||
): void {
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: ${eventCount} events exported successfully${withAuth ? ' (with auth)' : ' (without auth)'}`,
|
||||
)
|
||||
logForDebugging(`API Response: ${jsonStringify(responseData, null, 2)}`)
|
||||
}
|
||||
}
|
||||
|
||||
private hrTimeToDate(hrTime: HrTime): Date {
|
||||
const [seconds, nanoseconds] = hrTime
|
||||
return new Date(seconds * 1000 + nanoseconds / 1000000)
|
||||
}
|
||||
|
||||
private transformLogsToEvents(
|
||||
logs: ReadableLogRecord[],
|
||||
): FirstPartyEventLoggingPayload {
|
||||
const events: FirstPartyEventLoggingEvent[] = []
|
||||
|
||||
for (const log of logs) {
|
||||
const attributes = log.attributes || {}
|
||||
|
||||
// Check if this is a GrowthBook experiment event
|
||||
if (attributes.event_type === 'GrowthbookExperimentEvent') {
|
||||
const timestamp = this.hrTimeToDate(log.hrTime)
|
||||
const account_uuid = attributes.account_uuid as string | undefined
|
||||
const organization_uuid = attributes.organization_uuid as
|
||||
| string
|
||||
| undefined
|
||||
events.push({
|
||||
event_type: 'GrowthbookExperimentEvent',
|
||||
event_data: GrowthbookExperimentEvent.toJSON({
|
||||
event_id: attributes.event_id as string,
|
||||
timestamp,
|
||||
experiment_id: attributes.experiment_id as string,
|
||||
variation_id: attributes.variation_id as number,
|
||||
environment: attributes.environment as string,
|
||||
user_attributes: attributes.user_attributes as string,
|
||||
experiment_metadata: attributes.experiment_metadata as string,
|
||||
device_id: attributes.device_id as string,
|
||||
session_id: attributes.session_id as string,
|
||||
auth:
|
||||
account_uuid || organization_uuid
|
||||
? { account_uuid, organization_uuid }
|
||||
: undefined,
|
||||
}),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract event name
|
||||
const eventName =
|
||||
(attributes.event_name as string) || (log.body as string) || 'unknown'
|
||||
|
||||
// Extract metadata objects directly (no JSON parsing needed)
|
||||
const coreMetadata = attributes.core_metadata as EventMetadata | undefined
|
||||
const userMetadata = attributes.user_metadata as CoreUserData
|
||||
const eventMetadata = (attributes.event_metadata || {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>
|
||||
|
||||
if (!coreMetadata) {
|
||||
// Emit partial event if core metadata is missing
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(
|
||||
`1P event logging: core_metadata missing for event ${eventName}`,
|
||||
)
|
||||
}
|
||||
events.push({
|
||||
event_type: 'ClaudeCodeInternalEvent',
|
||||
event_data: ClaudeCodeInternalEvent.toJSON({
|
||||
event_id: attributes.event_id as string | undefined,
|
||||
event_name: eventName,
|
||||
client_timestamp: this.hrTimeToDate(log.hrTime),
|
||||
session_id: getSessionId(),
|
||||
additional_metadata: Buffer.from(
|
||||
jsonStringify({
|
||||
transform_error: 'core_metadata attribute is missing',
|
||||
}),
|
||||
).toString('base64'),
|
||||
}),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Transform to 1P format
|
||||
const formatted = to1PEventFormat(
|
||||
coreMetadata,
|
||||
userMetadata,
|
||||
eventMetadata,
|
||||
)
|
||||
|
||||
// _PROTO_* keys are PII-tagged values meant only for privileged BQ
|
||||
// columns. Hoist known keys to proto fields, then defensively strip any
|
||||
// remaining _PROTO_* so an unrecognized future key can't silently land
|
||||
// in the general-access additional_metadata blob. sink.ts applies the
|
||||
// same strip before Datadog; this closes the 1P side.
|
||||
const {
|
||||
_PROTO_skill_name,
|
||||
_PROTO_plugin_name,
|
||||
_PROTO_marketplace_name,
|
||||
...rest
|
||||
} = formatted.additional
|
||||
const additionalMetadata = stripProtoFields(rest)
|
||||
|
||||
events.push({
|
||||
event_type: 'ClaudeCodeInternalEvent',
|
||||
event_data: ClaudeCodeInternalEvent.toJSON({
|
||||
event_id: attributes.event_id as string | undefined,
|
||||
event_name: eventName,
|
||||
client_timestamp: this.hrTimeToDate(log.hrTime),
|
||||
device_id: attributes.user_id as string | undefined,
|
||||
email: userMetadata?.email,
|
||||
auth: formatted.auth,
|
||||
...formatted.core,
|
||||
env: formatted.env,
|
||||
process: formatted.process,
|
||||
skill_name:
|
||||
typeof _PROTO_skill_name === 'string'
|
||||
? _PROTO_skill_name
|
||||
: undefined,
|
||||
plugin_name:
|
||||
typeof _PROTO_plugin_name === 'string'
|
||||
? _PROTO_plugin_name
|
||||
: undefined,
|
||||
marketplace_name:
|
||||
typeof _PROTO_marketplace_name === 'string'
|
||||
? _PROTO_marketplace_name
|
||||
: undefined,
|
||||
additional_metadata:
|
||||
Object.keys(additionalMetadata).length > 0
|
||||
? Buffer.from(jsonStringify(additionalMetadata)).toString(
|
||||
'base64',
|
||||
)
|
||||
: undefined,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
return { events }
|
||||
}
|
||||
|
||||
async shutdown(): Promise<void> {
|
||||
this.isShutdown = true
|
||||
this.resetBackoff()
|
||||
await this.forceFlush()
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging('1P event logging exporter shutdown complete')
|
||||
}
|
||||
}
|
||||
|
||||
async forceFlush(): Promise<void> {
|
||||
await Promise.all(this.pendingExports)
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging('1P event logging exporter flush complete')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function getAxiosErrorContext(error: unknown): string {
|
||||
if (!axios.isAxiosError(error)) {
|
||||
return errorMessage(error)
|
||||
}
|
||||
|
||||
const parts: string[] = []
|
||||
|
||||
const requestId = error.response?.headers?.['request-id']
|
||||
if (requestId) {
|
||||
parts.push(`request-id=${requestId}`)
|
||||
}
|
||||
|
||||
if (error.response?.status) {
|
||||
parts.push(`status=${error.response.status}`)
|
||||
}
|
||||
|
||||
if (error.code) {
|
||||
parts.push(`code=${error.code}`)
|
||||
}
|
||||
|
||||
if (error.message) {
|
||||
parts.push(error.message)
|
||||
}
|
||||
|
||||
return parts.join(', ')
|
||||
}
|
||||
1155
src/services/analytics/growthbook.ts
Normal file
1155
src/services/analytics/growthbook.ts
Normal file
File diff suppressed because it is too large
Load Diff
173
src/services/analytics/index.ts
Normal file
173
src/services/analytics/index.ts
Normal file
@@ -0,0 +1,173 @@
|
||||
/**
|
||||
* Analytics service - public API for event logging
|
||||
*
|
||||
* This module serves as the main entry point for analytics events in Claude CLI.
|
||||
*
|
||||
* DESIGN: This module has NO dependencies to avoid import cycles.
|
||||
* Events are queued until attachAnalyticsSink() is called during app initialization.
|
||||
* The sink handles routing to Datadog and 1P event logging.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Marker type for verifying analytics metadata doesn't contain sensitive data
|
||||
*
|
||||
* This type forces explicit verification that string values being logged
|
||||
* don't contain code snippets, file paths, or other sensitive information.
|
||||
*
|
||||
* Usage: `myString as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS`
|
||||
*/
|
||||
export type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS = never
|
||||
|
||||
/**
|
||||
* Marker type for values routed to PII-tagged proto columns via `_PROTO_*`
|
||||
* payload keys. The destination BQ column has privileged access controls,
|
||||
* so unredacted values are acceptable — unlike general-access backends.
|
||||
*
|
||||
* sink.ts strips `_PROTO_*` keys before Datadog fanout; only the 1P
|
||||
* exporter (firstPartyEventLoggingExporter) sees them and hoists them to the
|
||||
* top-level proto field. A single stripProtoFields call guards all non-1P
|
||||
* sinks — no per-sink filtering to forget.
|
||||
*
|
||||
* Usage: `rawName as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED`
|
||||
*/
|
||||
export type AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED = never
|
||||
|
||||
/**
|
||||
* Strip `_PROTO_*` keys from a payload destined for general-access storage.
|
||||
* Used by:
|
||||
* - sink.ts: before Datadog fanout (never sees PII-tagged values)
|
||||
* - firstPartyEventLoggingExporter: defensive strip of additional_metadata
|
||||
* after hoisting known _PROTO_* keys to proto fields — prevents a future
|
||||
* unrecognized _PROTO_foo from silently landing in the BQ JSON blob.
|
||||
*
|
||||
* Returns the input unchanged (same reference) when no _PROTO_ keys present.
|
||||
*/
|
||||
export function stripProtoFields<V>(
|
||||
metadata: Record<string, V>,
|
||||
): Record<string, V> {
|
||||
let result: Record<string, V> | undefined
|
||||
for (const key in metadata) {
|
||||
if (key.startsWith('_PROTO_')) {
|
||||
if (result === undefined) {
|
||||
result = { ...metadata }
|
||||
}
|
||||
delete result[key]
|
||||
}
|
||||
}
|
||||
return result ?? metadata
|
||||
}
|
||||
|
||||
// Internal type for logEvent metadata - different from the enriched EventMetadata in metadata.ts
|
||||
type LogEventMetadata = { [key: string]: boolean | number | undefined }
|
||||
|
||||
type QueuedEvent = {
|
||||
eventName: string
|
||||
metadata: LogEventMetadata
|
||||
async: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Sink interface for the analytics backend
|
||||
*/
|
||||
export type AnalyticsSink = {
|
||||
logEvent: (eventName: string, metadata: LogEventMetadata) => void
|
||||
logEventAsync: (
|
||||
eventName: string,
|
||||
metadata: LogEventMetadata,
|
||||
) => Promise<void>
|
||||
}
|
||||
|
||||
// Event queue for events logged before sink is attached
|
||||
const eventQueue: QueuedEvent[] = []
|
||||
|
||||
// Sink - initialized during app startup
|
||||
let sink: AnalyticsSink | null = null
|
||||
|
||||
/**
|
||||
* Attach the analytics sink that will receive all events.
|
||||
* Queued events are drained asynchronously via queueMicrotask to avoid
|
||||
* adding latency to the startup path.
|
||||
*
|
||||
* Idempotent: if a sink is already attached, this is a no-op. This allows
|
||||
* calling from both the preAction hook (for subcommands) and setup() (for
|
||||
* the default command) without coordination.
|
||||
*/
|
||||
export function attachAnalyticsSink(newSink: AnalyticsSink): void {
|
||||
if (sink !== null) {
|
||||
return
|
||||
}
|
||||
sink = newSink
|
||||
|
||||
// Drain the queue asynchronously to avoid blocking startup
|
||||
if (eventQueue.length > 0) {
|
||||
const queuedEvents = [...eventQueue]
|
||||
eventQueue.length = 0
|
||||
|
||||
// Log queue size for ants to help debug analytics initialization timing
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
sink.logEvent('analytics_sink_attached', {
|
||||
queued_event_count: queuedEvents.length,
|
||||
})
|
||||
}
|
||||
|
||||
queueMicrotask(() => {
|
||||
for (const event of queuedEvents) {
|
||||
if (event.async) {
|
||||
void sink!.logEventAsync(event.eventName, event.metadata)
|
||||
} else {
|
||||
sink!.logEvent(event.eventName, event.metadata)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Log an event to analytics backends (synchronous)
|
||||
*
|
||||
* Events may be sampled based on the 'tengu_event_sampling_config' dynamic config.
|
||||
* When sampled, the sample_rate is added to the event metadata.
|
||||
*
|
||||
* If no sink is attached, events are queued and drained when the sink attaches.
|
||||
*/
|
||||
export function logEvent(
|
||||
eventName: string,
|
||||
// intentionally no strings unless AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
// to avoid accidentally logging code/filepaths
|
||||
metadata: LogEventMetadata,
|
||||
): void {
|
||||
if (sink === null) {
|
||||
eventQueue.push({ eventName, metadata, async: false })
|
||||
return
|
||||
}
|
||||
sink.logEvent(eventName, metadata)
|
||||
}
|
||||
|
||||
/**
|
||||
* Log an event to analytics backends (asynchronous)
|
||||
*
|
||||
* Events may be sampled based on the 'tengu_event_sampling_config' dynamic config.
|
||||
* When sampled, the sample_rate is added to the event metadata.
|
||||
*
|
||||
* If no sink is attached, events are queued and drained when the sink attaches.
|
||||
*/
|
||||
export async function logEventAsync(
|
||||
eventName: string,
|
||||
// intentionally no strings, to avoid accidentally logging code/filepaths
|
||||
metadata: LogEventMetadata,
|
||||
): Promise<void> {
|
||||
if (sink === null) {
|
||||
eventQueue.push({ eventName, metadata, async: true })
|
||||
return
|
||||
}
|
||||
await sink.logEventAsync(eventName, metadata)
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset analytics state for testing purposes only.
|
||||
* @internal
|
||||
*/
|
||||
export function _resetForTesting(): void {
|
||||
sink = null
|
||||
eventQueue.length = 0
|
||||
}
|
||||
973
src/services/analytics/metadata.ts
Normal file
973
src/services/analytics/metadata.ts
Normal file
@@ -0,0 +1,973 @@
|
||||
// biome-ignore-all assist/source/organizeImports: ANT-ONLY import markers must not be reordered
|
||||
/**
|
||||
* Shared event metadata enrichment for analytics systems
|
||||
*
|
||||
* This module provides a single source of truth for collecting and formatting
|
||||
* event metadata across all analytics systems (Datadog, 1P).
|
||||
*/
|
||||
|
||||
import { extname } from 'path'
|
||||
import memoize from 'lodash-es/memoize.js'
|
||||
import { env, getHostPlatformForAnalytics } from '../../utils/env.js'
|
||||
import { envDynamic } from '../../utils/envDynamic.js'
|
||||
import { getModelBetas } from '../../utils/betas.js'
|
||||
import { getMainLoopModel } from '../../utils/model/model.js'
|
||||
import {
|
||||
getSessionId,
|
||||
getIsInteractive,
|
||||
getKairosActive,
|
||||
getClientType,
|
||||
getParentSessionId as getParentSessionIdFromState,
|
||||
} from '../../bootstrap/state.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { isOfficialMcpUrl } from '../mcp/officialRegistry.js'
|
||||
import { isClaudeAISubscriber, getSubscriptionType } from '../../utils/auth.js'
|
||||
import { getRepoRemoteHash } from '../../utils/git.js'
|
||||
import {
|
||||
getWslVersion,
|
||||
getLinuxDistroInfo,
|
||||
detectVcs,
|
||||
} from '../../utils/platform.js'
|
||||
import type { CoreUserData } from 'src/utils/user.js'
|
||||
import { getAgentContext } from '../../utils/agentContext.js'
|
||||
import type { EnvironmentMetadata } from '../../types/generated/events_mono/claude_code/v1/claude_code_internal_event.js'
|
||||
import type { PublicApiAuth } from '../../types/generated/events_mono/common/v1/auth.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import {
|
||||
getAgentId,
|
||||
getParentSessionId as getTeammateParentSessionId,
|
||||
getTeamName,
|
||||
isTeammate,
|
||||
} from '../../utils/teammate.js'
|
||||
import { feature } from 'bun:bundle'
|
||||
|
||||
/**
|
||||
* Marker type for verifying analytics metadata doesn't contain sensitive data
|
||||
*
|
||||
* This type forces explicit verification that string values being logged
|
||||
* don't contain code snippets, file paths, or other sensitive information.
|
||||
*
|
||||
* The metadata is expected to be JSON-serializable.
|
||||
*
|
||||
* Usage: `myString as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS`
|
||||
*
|
||||
* The type is `never` which means it can never actually hold a value - this is
|
||||
* intentional as it's only used for type-casting to document developer intent.
|
||||
*/
|
||||
export type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS = never
|
||||
|
||||
/**
|
||||
* Sanitizes tool names for analytics logging to avoid PII exposure.
|
||||
*
|
||||
* MCP tool names follow the format `mcp__<server>__<tool>` and can reveal
|
||||
* user-specific server configurations, which is considered PII-medium.
|
||||
* This function redacts MCP tool names while preserving built-in tool names
|
||||
* (Bash, Read, Write, etc.) which are safe to log.
|
||||
*
|
||||
* @param toolName - The tool name to sanitize
|
||||
* @returns The original name for built-in tools, or 'mcp_tool' for MCP tools
|
||||
*/
|
||||
export function sanitizeToolNameForAnalytics(
|
||||
toolName: string,
|
||||
): AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS {
|
||||
if (toolName.startsWith('mcp__')) {
|
||||
return 'mcp_tool' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
}
|
||||
return toolName as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if detailed tool name logging is enabled for OTLP events.
|
||||
* When enabled, MCP server/tool names and Skill names are logged.
|
||||
* Disabled by default to protect PII (user-specific server configurations).
|
||||
*
|
||||
* Enable with OTEL_LOG_TOOL_DETAILS=1
|
||||
*/
|
||||
export function isToolDetailsLoggingEnabled(): boolean {
|
||||
return isEnvTruthy(process.env.OTEL_LOG_TOOL_DETAILS)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if detailed tool name logging (MCP server/tool names) is enabled
|
||||
* for analytics events.
|
||||
*
|
||||
* Per go/taxonomy, MCP names are medium PII. We log them for:
|
||||
* - Cowork (entrypoint=local-agent) — no ZDR concept, log all MCPs
|
||||
* - claude.ai-proxied connectors — always official (from claude.ai's list)
|
||||
* - Servers whose URL matches the official MCP registry — directory
|
||||
* connectors added via `claude mcp add`, not customer-specific config
|
||||
*
|
||||
* Custom/user-configured MCPs stay sanitized (toolName='mcp_tool').
|
||||
*/
|
||||
export function isAnalyticsToolDetailsLoggingEnabled(
|
||||
mcpServerType: string | undefined,
|
||||
mcpServerBaseUrl: string | undefined,
|
||||
): boolean {
|
||||
if (process.env.CLAUDE_CODE_ENTRYPOINT === 'local-agent') {
|
||||
return true
|
||||
}
|
||||
if (mcpServerType === 'claudeai-proxy') {
|
||||
return true
|
||||
}
|
||||
if (mcpServerBaseUrl && isOfficialMcpUrl(mcpServerBaseUrl)) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Built-in first-party MCP servers whose names are fixed reserved strings,
|
||||
* not user-configured — so logging them is not PII. Checked in addition to
|
||||
* isAnalyticsToolDetailsLoggingEnabled's transport/URL gates, which a stdio
|
||||
* built-in would otherwise fail.
|
||||
*
|
||||
* Feature-gated so the set is empty when the feature is off: the name
|
||||
* reservation (main.tsx, config.ts addMcpServer) is itself feature-gated, so
|
||||
* a user-configured 'computer-use' is possible in builds without the feature.
|
||||
*/
|
||||
/* eslint-disable @typescript-eslint/no-require-imports */
|
||||
const BUILTIN_MCP_SERVER_NAMES: ReadonlySet<string> = new Set(
|
||||
feature('CHICAGO_MCP')
|
||||
? [
|
||||
(
|
||||
require('../../utils/computerUse/common.js') as typeof import('../../utils/computerUse/common.js')
|
||||
).COMPUTER_USE_MCP_SERVER_NAME,
|
||||
]
|
||||
: [],
|
||||
)
|
||||
/* eslint-enable @typescript-eslint/no-require-imports */
|
||||
|
||||
/**
|
||||
* Spreadable helper for logEvent payloads — returns {mcpServerName, mcpToolName}
|
||||
* if the gate passes, empty object otherwise. Consolidates the identical IIFE
|
||||
* pattern at each tengu_tool_use_* call site.
|
||||
*/
|
||||
export function mcpToolDetailsForAnalytics(
|
||||
toolName: string,
|
||||
mcpServerType: string | undefined,
|
||||
mcpServerBaseUrl: string | undefined,
|
||||
): {
|
||||
mcpServerName?: AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
mcpToolName?: AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
} {
|
||||
const details = extractMcpToolDetails(toolName)
|
||||
if (!details) {
|
||||
return {}
|
||||
}
|
||||
if (
|
||||
!BUILTIN_MCP_SERVER_NAMES.has(details.serverName) &&
|
||||
!isAnalyticsToolDetailsLoggingEnabled(mcpServerType, mcpServerBaseUrl)
|
||||
) {
|
||||
return {}
|
||||
}
|
||||
return {
|
||||
mcpServerName: details.serverName,
|
||||
mcpToolName: details.mcpToolName,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract MCP server and tool names from a full MCP tool name.
|
||||
* MCP tool names follow the format: mcp__<server>__<tool>
|
||||
*
|
||||
* @param toolName - The full tool name (e.g., 'mcp__slack__read_channel')
|
||||
* @returns Object with serverName and toolName, or undefined if not an MCP tool
|
||||
*/
|
||||
export function extractMcpToolDetails(toolName: string):
|
||||
| {
|
||||
serverName: AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
mcpToolName: AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
}
|
||||
| undefined {
|
||||
if (!toolName.startsWith('mcp__')) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Format: mcp__<server>__<tool>
|
||||
const parts = toolName.split('__')
|
||||
if (parts.length < 3) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const serverName = parts[1]
|
||||
// Tool name may contain __ so rejoin remaining parts
|
||||
const mcpToolName = parts.slice(2).join('__')
|
||||
|
||||
if (!serverName || !mcpToolName) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return {
|
||||
serverName:
|
||||
serverName as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
mcpToolName:
|
||||
mcpToolName as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract skill name from Skill tool input.
|
||||
*
|
||||
* @param toolName - The tool name (should be 'Skill')
|
||||
* @param input - The tool input containing the skill name
|
||||
* @returns The skill name if this is a Skill tool call, undefined otherwise
|
||||
*/
|
||||
export function extractSkillName(
|
||||
toolName: string,
|
||||
input: unknown,
|
||||
): AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS | undefined {
|
||||
if (toolName !== 'Skill') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (
|
||||
typeof input === 'object' &&
|
||||
input !== null &&
|
||||
'skill' in input &&
|
||||
typeof (input as { skill: unknown }).skill === 'string'
|
||||
) {
|
||||
return (input as { skill: string })
|
||||
.skill as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
const TOOL_INPUT_STRING_TRUNCATE_AT = 512
|
||||
const TOOL_INPUT_STRING_TRUNCATE_TO = 128
|
||||
const TOOL_INPUT_MAX_JSON_CHARS = 4 * 1024
|
||||
const TOOL_INPUT_MAX_COLLECTION_ITEMS = 20
|
||||
const TOOL_INPUT_MAX_DEPTH = 2
|
||||
|
||||
function truncateToolInputValue(value: unknown, depth = 0): unknown {
|
||||
if (typeof value === 'string') {
|
||||
if (value.length > TOOL_INPUT_STRING_TRUNCATE_AT) {
|
||||
return `${value.slice(0, TOOL_INPUT_STRING_TRUNCATE_TO)}…[${value.length} chars]`
|
||||
}
|
||||
return value
|
||||
}
|
||||
if (
|
||||
typeof value === 'number' ||
|
||||
typeof value === 'boolean' ||
|
||||
value === null ||
|
||||
value === undefined
|
||||
) {
|
||||
return value
|
||||
}
|
||||
if (depth >= TOOL_INPUT_MAX_DEPTH) {
|
||||
return '<nested>'
|
||||
}
|
||||
if (Array.isArray(value)) {
|
||||
const mapped = value
|
||||
.slice(0, TOOL_INPUT_MAX_COLLECTION_ITEMS)
|
||||
.map(v => truncateToolInputValue(v, depth + 1))
|
||||
if (value.length > TOOL_INPUT_MAX_COLLECTION_ITEMS) {
|
||||
mapped.push(`…[${value.length} items]`)
|
||||
}
|
||||
return mapped
|
||||
}
|
||||
if (typeof value === 'object') {
|
||||
const entries = Object.entries(value as Record<string, unknown>)
|
||||
// Skip internal marker keys (e.g. _simulatedSedEdit re-introduced by
|
||||
// SedEditPermissionRequest) so they don't leak into telemetry.
|
||||
.filter(([k]) => !k.startsWith('_'))
|
||||
const mapped = entries
|
||||
.slice(0, TOOL_INPUT_MAX_COLLECTION_ITEMS)
|
||||
.map(([k, v]) => [k, truncateToolInputValue(v, depth + 1)])
|
||||
if (entries.length > TOOL_INPUT_MAX_COLLECTION_ITEMS) {
|
||||
mapped.push(['…', `${entries.length} keys`])
|
||||
}
|
||||
return Object.fromEntries(mapped)
|
||||
}
|
||||
return String(value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize a tool's input arguments for the OTel tool_result event.
|
||||
* Truncates long strings and deep nesting to keep the output bounded while
|
||||
* preserving forensically useful fields like file paths, URLs, and MCP args.
|
||||
* Returns undefined when OTEL_LOG_TOOL_DETAILS is not enabled.
|
||||
*/
|
||||
export function extractToolInputForTelemetry(
|
||||
input: unknown,
|
||||
): string | undefined {
|
||||
if (!isToolDetailsLoggingEnabled()) {
|
||||
return undefined
|
||||
}
|
||||
const truncated = truncateToolInputValue(input)
|
||||
let json = jsonStringify(truncated)
|
||||
if (json.length > TOOL_INPUT_MAX_JSON_CHARS) {
|
||||
json = json.slice(0, TOOL_INPUT_MAX_JSON_CHARS) + '…[truncated]'
|
||||
}
|
||||
return json
|
||||
}
|
||||
|
||||
/**
|
||||
* Maximum length for file extensions to be logged.
|
||||
* Extensions longer than this are considered potentially sensitive
|
||||
* (e.g., hash-based filenames like "key-hash-abcd-123-456") and
|
||||
* will be replaced with 'other'.
|
||||
*/
|
||||
const MAX_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
/**
|
||||
* Extracts and sanitizes a file extension for analytics logging.
|
||||
*
|
||||
* Uses Node's path.extname for reliable cross-platform extension extraction.
|
||||
* Returns 'other' for extensions exceeding MAX_FILE_EXTENSION_LENGTH to avoid
|
||||
* logging potentially sensitive data (like hash-based filenames).
|
||||
*
|
||||
* @param filePath - The file path to extract the extension from
|
||||
* @returns The sanitized extension, 'other' for long extensions, or undefined if no extension
|
||||
*/
|
||||
export function getFileExtensionForAnalytics(
|
||||
filePath: string,
|
||||
): AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS | undefined {
|
||||
const ext = extname(filePath).toLowerCase()
|
||||
if (!ext || ext === '.') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const extension = ext.slice(1) // remove leading dot
|
||||
if (extension.length > MAX_FILE_EXTENSION_LENGTH) {
|
||||
return 'other' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
}
|
||||
|
||||
return extension as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
}
|
||||
|
||||
/** Allow list of commands we extract file extensions from. */
|
||||
const FILE_COMMANDS = new Set([
|
||||
'rm',
|
||||
'mv',
|
||||
'cp',
|
||||
'touch',
|
||||
'mkdir',
|
||||
'chmod',
|
||||
'chown',
|
||||
'cat',
|
||||
'head',
|
||||
'tail',
|
||||
'sort',
|
||||
'stat',
|
||||
'diff',
|
||||
'wc',
|
||||
'grep',
|
||||
'rg',
|
||||
'sed',
|
||||
])
|
||||
|
||||
/** Regex to split bash commands on compound operators (&&, ||, ;, |). */
|
||||
const COMPOUND_OPERATOR_REGEX = /\s*(?:&&|\|\||[;|])\s*/
|
||||
|
||||
/** Regex to split on whitespace. */
|
||||
const WHITESPACE_REGEX = /\s+/
|
||||
|
||||
/**
|
||||
* Extracts file extensions from a bash command for analytics.
|
||||
* Best-effort: splits on operators and whitespace, extracts extensions
|
||||
* from non-flag args of allowed commands. No heavy shell parsing needed
|
||||
* because grep patterns and sed scripts rarely resemble file extensions.
|
||||
*/
|
||||
export function getFileExtensionsFromBashCommand(
|
||||
command: string,
|
||||
simulatedSedEditFilePath?: string,
|
||||
): AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS | undefined {
|
||||
if (!command.includes('.') && !simulatedSedEditFilePath) return undefined
|
||||
|
||||
let result: string | undefined
|
||||
const seen = new Set<string>()
|
||||
|
||||
if (simulatedSedEditFilePath) {
|
||||
const ext = getFileExtensionForAnalytics(simulatedSedEditFilePath)
|
||||
if (ext) {
|
||||
seen.add(ext)
|
||||
result = ext
|
||||
}
|
||||
}
|
||||
|
||||
for (const subcmd of command.split(COMPOUND_OPERATOR_REGEX)) {
|
||||
if (!subcmd) continue
|
||||
const tokens = subcmd.split(WHITESPACE_REGEX)
|
||||
if (tokens.length < 2) continue
|
||||
|
||||
const firstToken = tokens[0]!
|
||||
const slashIdx = firstToken.lastIndexOf('/')
|
||||
const baseCmd = slashIdx >= 0 ? firstToken.slice(slashIdx + 1) : firstToken
|
||||
if (!FILE_COMMANDS.has(baseCmd)) continue
|
||||
|
||||
for (let i = 1; i < tokens.length; i++) {
|
||||
const arg = tokens[i]!
|
||||
if (arg.charCodeAt(0) === 45 /* - */) continue
|
||||
const ext = getFileExtensionForAnalytics(arg)
|
||||
if (ext && !seen.has(ext)) {
|
||||
seen.add(ext)
|
||||
result = result ? result + ',' + ext : ext
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!result) return undefined
|
||||
return result as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS
|
||||
}
|
||||
|
||||
/**
|
||||
* Environment context metadata
|
||||
*/
|
||||
export type EnvContext = {
|
||||
platform: string
|
||||
platformRaw: string
|
||||
arch: string
|
||||
nodeVersion: string
|
||||
terminal: string | null
|
||||
packageManagers: string
|
||||
runtimes: string
|
||||
isRunningWithBun: boolean
|
||||
isCi: boolean
|
||||
isClaubbit: boolean
|
||||
isClaudeCodeRemote: boolean
|
||||
isLocalAgentMode: boolean
|
||||
isConductor: boolean
|
||||
remoteEnvironmentType?: string
|
||||
coworkerType?: string
|
||||
claudeCodeContainerId?: string
|
||||
claudeCodeRemoteSessionId?: string
|
||||
tags?: string
|
||||
isGithubAction: boolean
|
||||
isClaudeCodeAction: boolean
|
||||
isClaudeAiAuth: boolean
|
||||
version: string
|
||||
versionBase?: string
|
||||
buildTime: string
|
||||
deploymentEnvironment: string
|
||||
githubEventName?: string
|
||||
githubActionsRunnerEnvironment?: string
|
||||
githubActionsRunnerOs?: string
|
||||
githubActionRef?: string
|
||||
wslVersion?: string
|
||||
linuxDistroId?: string
|
||||
linuxDistroVersion?: string
|
||||
linuxKernel?: string
|
||||
vcs?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Process metrics included with all analytics events.
|
||||
*/
|
||||
export type ProcessMetrics = {
|
||||
uptime: number
|
||||
rss: number
|
||||
heapTotal: number
|
||||
heapUsed: number
|
||||
external: number
|
||||
arrayBuffers: number
|
||||
constrainedMemory: number | undefined
|
||||
cpuUsage: NodeJS.CpuUsage
|
||||
cpuPercent: number | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Core event metadata shared across all analytics systems
|
||||
*/
|
||||
export type EventMetadata = {
|
||||
model: string
|
||||
sessionId: string
|
||||
userType: string
|
||||
betas?: string
|
||||
envContext: EnvContext
|
||||
entrypoint?: string
|
||||
agentSdkVersion?: string
|
||||
isInteractive: string
|
||||
clientType: string
|
||||
processMetrics?: ProcessMetrics
|
||||
sweBenchRunId: string
|
||||
sweBenchInstanceId: string
|
||||
sweBenchTaskId: string
|
||||
// Swarm/team agent identification for analytics attribution
|
||||
agentId?: string // CLAUDE_CODE_AGENT_ID (format: agentName@teamName) or subagent UUID
|
||||
parentSessionId?: string // CLAUDE_CODE_PARENT_SESSION_ID (team lead's session)
|
||||
agentType?: 'teammate' | 'subagent' | 'standalone' // Distinguishes swarm teammates, Agent tool subagents, and standalone agents
|
||||
teamName?: string // Team name for swarm agents (from env var or AsyncLocalStorage)
|
||||
subscriptionType?: string // OAuth subscription tier (max, pro, enterprise, team)
|
||||
rh?: string // Hashed repo remote URL (first 16 chars of SHA256), for joining with server-side data
|
||||
kairosActive?: true // KAIROS assistant mode active (ant-only; set in main.tsx after gate check)
|
||||
skillMode?: 'discovery' | 'coach' | 'discovery_and_coach' // Which skill surfacing mechanism(s) are gated on (ant-only; for BQ session segmentation)
|
||||
observerMode?: 'backseat' | 'skillcoach' | 'both' // Which observer classifiers are gated on (ant-only; for BQ cohort splits on tengu_backseat_* events)
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for enriching event metadata
|
||||
*/
|
||||
export type EnrichMetadataOptions = {
|
||||
// Model to use, falls back to getMainLoopModel() if not provided
|
||||
model?: unknown
|
||||
// Explicit betas string (already joined)
|
||||
betas?: unknown
|
||||
// Additional metadata to include (optional)
|
||||
additionalMetadata?: Record<string, unknown>
|
||||
}
|
||||
|
||||
/**
|
||||
* Get agent identification for analytics.
|
||||
* Priority: AsyncLocalStorage context (subagents) > env vars (swarm teammates)
|
||||
*/
|
||||
function getAgentIdentification(): {
|
||||
agentId?: string
|
||||
parentSessionId?: string
|
||||
agentType?: 'teammate' | 'subagent' | 'standalone'
|
||||
teamName?: string
|
||||
} {
|
||||
// Check AsyncLocalStorage first (for subagents running in same process)
|
||||
const agentContext = getAgentContext()
|
||||
if (agentContext) {
|
||||
const result: ReturnType<typeof getAgentIdentification> = {
|
||||
agentId: agentContext.agentId,
|
||||
parentSessionId: agentContext.parentSessionId,
|
||||
agentType: agentContext.agentType,
|
||||
}
|
||||
if (agentContext.agentType === 'teammate') {
|
||||
result.teamName = agentContext.teamName
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Fall back to swarm helpers (for swarm agents)
|
||||
const agentId = getAgentId()
|
||||
const parentSessionId = getTeammateParentSessionId()
|
||||
const teamName = getTeamName()
|
||||
const isSwarmAgent = isTeammate()
|
||||
// For standalone agents (have agent ID but not a teammate), set agentType to 'standalone'
|
||||
const agentType = isSwarmAgent
|
||||
? ('teammate' as const)
|
||||
: agentId
|
||||
? ('standalone' as const)
|
||||
: undefined
|
||||
if (agentId || agentType || parentSessionId || teamName) {
|
||||
return {
|
||||
...(agentId ? { agentId } : {}),
|
||||
...(agentType ? { agentType } : {}),
|
||||
...(parentSessionId ? { parentSessionId } : {}),
|
||||
...(teamName ? { teamName } : {}),
|
||||
}
|
||||
}
|
||||
|
||||
// Check bootstrap state for parent session ID (e.g., plan mode -> implementation)
|
||||
const stateParentSessionId = getParentSessionIdFromState()
|
||||
if (stateParentSessionId) {
|
||||
return { parentSessionId: stateParentSessionId }
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract base version from full version string. "2.0.36-dev.20251107.t174150.sha2709699" → "2.0.36-dev"
|
||||
*/
|
||||
const getVersionBase = memoize((): string | undefined => {
|
||||
const match = MACRO.VERSION.match(/^\d+\.\d+\.\d+(?:-[a-z]+)?/)
|
||||
return match ? match[0] : undefined
|
||||
})
|
||||
|
||||
/**
|
||||
* Builds the environment context object
|
||||
*/
|
||||
const buildEnvContext = memoize(async (): Promise<EnvContext> => {
|
||||
const [packageManagers, runtimes, linuxDistroInfo, vcs] = await Promise.all([
|
||||
env.getPackageManagers(),
|
||||
env.getRuntimes(),
|
||||
getLinuxDistroInfo(),
|
||||
detectVcs(),
|
||||
])
|
||||
|
||||
return {
|
||||
platform: getHostPlatformForAnalytics(),
|
||||
// Raw process.platform so freebsd/openbsd/aix/sunos are visible in BQ.
|
||||
// getHostPlatformForAnalytics() buckets those into 'linux'; here we want
|
||||
// the truth. CLAUDE_CODE_HOST_PLATFORM still overrides for container/remote.
|
||||
platformRaw: process.env.CLAUDE_CODE_HOST_PLATFORM || process.platform,
|
||||
arch: env.arch,
|
||||
nodeVersion: env.nodeVersion,
|
||||
terminal: envDynamic.terminal,
|
||||
packageManagers: packageManagers.join(','),
|
||||
runtimes: runtimes.join(','),
|
||||
isRunningWithBun: env.isRunningWithBun(),
|
||||
isCi: isEnvTruthy(process.env.CI),
|
||||
isClaubbit: isEnvTruthy(process.env.CLAUBBIT),
|
||||
isClaudeCodeRemote: isEnvTruthy(process.env.CLAUDE_CODE_REMOTE),
|
||||
isLocalAgentMode: process.env.CLAUDE_CODE_ENTRYPOINT === 'local-agent',
|
||||
isConductor: env.isConductor(),
|
||||
...(process.env.CLAUDE_CODE_REMOTE_ENVIRONMENT_TYPE && {
|
||||
remoteEnvironmentType: process.env.CLAUDE_CODE_REMOTE_ENVIRONMENT_TYPE,
|
||||
}),
|
||||
// Gated by feature flag to prevent leaking "coworkerType" string in external builds
|
||||
...(feature('COWORKER_TYPE_TELEMETRY')
|
||||
? process.env.CLAUDE_CODE_COWORKER_TYPE
|
||||
? { coworkerType: process.env.CLAUDE_CODE_COWORKER_TYPE }
|
||||
: {}
|
||||
: {}),
|
||||
...(process.env.CLAUDE_CODE_CONTAINER_ID && {
|
||||
claudeCodeContainerId: process.env.CLAUDE_CODE_CONTAINER_ID,
|
||||
}),
|
||||
...(process.env.CLAUDE_CODE_REMOTE_SESSION_ID && {
|
||||
claudeCodeRemoteSessionId: process.env.CLAUDE_CODE_REMOTE_SESSION_ID,
|
||||
}),
|
||||
...(process.env.CLAUDE_CODE_TAGS && {
|
||||
tags: process.env.CLAUDE_CODE_TAGS,
|
||||
}),
|
||||
isGithubAction: isEnvTruthy(process.env.GITHUB_ACTIONS),
|
||||
isClaudeCodeAction: isEnvTruthy(process.env.CLAUDE_CODE_ACTION),
|
||||
isClaudeAiAuth: isClaudeAISubscriber(),
|
||||
version: MACRO.VERSION,
|
||||
versionBase: getVersionBase(),
|
||||
buildTime: MACRO.BUILD_TIME,
|
||||
deploymentEnvironment: env.detectDeploymentEnvironment(),
|
||||
...(isEnvTruthy(process.env.GITHUB_ACTIONS) && {
|
||||
githubEventName: process.env.GITHUB_EVENT_NAME,
|
||||
githubActionsRunnerEnvironment: process.env.RUNNER_ENVIRONMENT,
|
||||
githubActionsRunnerOs: process.env.RUNNER_OS,
|
||||
githubActionRef: process.env.GITHUB_ACTION_PATH?.includes(
|
||||
'claude-code-action/',
|
||||
)
|
||||
? process.env.GITHUB_ACTION_PATH.split('claude-code-action/')[1]
|
||||
: undefined,
|
||||
}),
|
||||
...(getWslVersion() && { wslVersion: getWslVersion() }),
|
||||
...(linuxDistroInfo ?? {}),
|
||||
...(vcs.length > 0 ? { vcs: vcs.join(',') } : {}),
|
||||
}
|
||||
})
|
||||
|
||||
// --
|
||||
// CPU% delta tracking — inherently process-global, same pattern as logBatch/flushTimer in datadog.ts
|
||||
let prevCpuUsage: NodeJS.CpuUsage | null = null
|
||||
let prevWallTimeMs: number | null = null
|
||||
|
||||
/**
|
||||
* Builds process metrics object for all users.
|
||||
*/
|
||||
function buildProcessMetrics(): ProcessMetrics | undefined {
|
||||
try {
|
||||
const mem = process.memoryUsage()
|
||||
const cpu = process.cpuUsage()
|
||||
const now = Date.now()
|
||||
|
||||
let cpuPercent: number | undefined
|
||||
if (prevCpuUsage && prevWallTimeMs) {
|
||||
const wallDeltaMs = now - prevWallTimeMs
|
||||
if (wallDeltaMs > 0) {
|
||||
const userDeltaUs = cpu.user - prevCpuUsage.user
|
||||
const systemDeltaUs = cpu.system - prevCpuUsage.system
|
||||
cpuPercent =
|
||||
((userDeltaUs + systemDeltaUs) / (wallDeltaMs * 1000)) * 100
|
||||
}
|
||||
}
|
||||
prevCpuUsage = cpu
|
||||
prevWallTimeMs = now
|
||||
|
||||
return {
|
||||
uptime: process.uptime(),
|
||||
rss: mem.rss,
|
||||
heapTotal: mem.heapTotal,
|
||||
heapUsed: mem.heapUsed,
|
||||
external: mem.external,
|
||||
arrayBuffers: mem.arrayBuffers,
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
constrainedMemory: process.constrainedMemory(),
|
||||
cpuUsage: cpu,
|
||||
cpuPercent,
|
||||
}
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get core event metadata shared across all analytics systems.
|
||||
*
|
||||
* This function collects environment, runtime, and context information
|
||||
* that should be included with all analytics events.
|
||||
*
|
||||
* @param options - Configuration options
|
||||
* @returns Promise resolving to enriched metadata object
|
||||
*/
|
||||
export async function getEventMetadata(
|
||||
options: EnrichMetadataOptions = {},
|
||||
): Promise<EventMetadata> {
|
||||
const model = options.model ? String(options.model) : getMainLoopModel()
|
||||
const betas =
|
||||
typeof options.betas === 'string'
|
||||
? options.betas
|
||||
: getModelBetas(model).join(',')
|
||||
const [envContext, repoRemoteHash] = await Promise.all([
|
||||
buildEnvContext(),
|
||||
getRepoRemoteHash(),
|
||||
])
|
||||
const processMetrics = buildProcessMetrics()
|
||||
|
||||
const metadata: EventMetadata = {
|
||||
model,
|
||||
sessionId: getSessionId(),
|
||||
userType: process.env.USER_TYPE || '',
|
||||
...(betas.length > 0 ? { betas: betas } : {}),
|
||||
envContext,
|
||||
...(process.env.CLAUDE_CODE_ENTRYPOINT && {
|
||||
entrypoint: process.env.CLAUDE_CODE_ENTRYPOINT,
|
||||
}),
|
||||
...(process.env.CLAUDE_AGENT_SDK_VERSION && {
|
||||
agentSdkVersion: process.env.CLAUDE_AGENT_SDK_VERSION,
|
||||
}),
|
||||
isInteractive: String(getIsInteractive()),
|
||||
clientType: getClientType(),
|
||||
...(processMetrics && { processMetrics }),
|
||||
sweBenchRunId: process.env.SWE_BENCH_RUN_ID || '',
|
||||
sweBenchInstanceId: process.env.SWE_BENCH_INSTANCE_ID || '',
|
||||
sweBenchTaskId: process.env.SWE_BENCH_TASK_ID || '',
|
||||
// Swarm/team agent identification
|
||||
// Priority: AsyncLocalStorage context (subagents) > env vars (swarm teammates)
|
||||
...getAgentIdentification(),
|
||||
// Subscription tier for DAU-by-tier analytics
|
||||
...(getSubscriptionType() && {
|
||||
subscriptionType: getSubscriptionType()!,
|
||||
}),
|
||||
// Assistant mode tag — lives outside memoized buildEnvContext() because
|
||||
// setKairosActive() runs at main.tsx:~1648, after the first event may
|
||||
// have already fired and memoized the env. Read fresh per-event instead.
|
||||
...(feature('KAIROS') && getKairosActive()
|
||||
? { kairosActive: true as const }
|
||||
: {}),
|
||||
// Repo remote hash for joining with server-side repo bundle data
|
||||
...(repoRemoteHash && { rh: repoRemoteHash }),
|
||||
}
|
||||
|
||||
return metadata
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Core event metadata for 1P event logging (snake_case format).
|
||||
*/
|
||||
export type FirstPartyEventLoggingCoreMetadata = {
|
||||
session_id: string
|
||||
model: string
|
||||
user_type: string
|
||||
betas?: string
|
||||
entrypoint?: string
|
||||
agent_sdk_version?: string
|
||||
is_interactive: boolean
|
||||
client_type: string
|
||||
swe_bench_run_id?: string
|
||||
swe_bench_instance_id?: string
|
||||
swe_bench_task_id?: string
|
||||
// Swarm/team agent identification
|
||||
agent_id?: string
|
||||
parent_session_id?: string
|
||||
agent_type?: 'teammate' | 'subagent' | 'standalone'
|
||||
team_name?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete event logging metadata format for 1P events.
|
||||
*/
|
||||
export type FirstPartyEventLoggingMetadata = {
|
||||
env: EnvironmentMetadata
|
||||
process?: string
|
||||
// auth is a top-level field on ClaudeCodeInternalEvent (proto PublicApiAuth).
|
||||
// account_id is intentionally omitted — only UUID fields are populated client-side.
|
||||
auth?: PublicApiAuth
|
||||
// core fields correspond to the top level of ClaudeCodeInternalEvent.
|
||||
// They get directly exported to their individual columns in the BigQuery tables
|
||||
core: FirstPartyEventLoggingCoreMetadata
|
||||
// additional fields are populated in the additional_metadata field of the
|
||||
// ClaudeCodeInternalEvent proto. Includes but is not limited to information
|
||||
// that differs by event type.
|
||||
additional: Record<string, unknown>
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert metadata to 1P event logging format (snake_case fields).
|
||||
*
|
||||
* The /api/event_logging/batch endpoint expects snake_case field names
|
||||
* for environment and core metadata.
|
||||
*
|
||||
* @param metadata - Core event metadata
|
||||
* @param additionalMetadata - Additional metadata to include
|
||||
* @returns Metadata formatted for 1P event logging
|
||||
*/
|
||||
export function to1PEventFormat(
|
||||
metadata: EventMetadata,
|
||||
userMetadata: CoreUserData,
|
||||
additionalMetadata: Record<string, unknown> = {},
|
||||
): FirstPartyEventLoggingMetadata {
|
||||
const {
|
||||
envContext,
|
||||
processMetrics,
|
||||
rh,
|
||||
kairosActive,
|
||||
skillMode,
|
||||
observerMode,
|
||||
...coreFields
|
||||
} = metadata
|
||||
|
||||
// Convert envContext to snake_case.
|
||||
// IMPORTANT: env is typed as the proto-generated EnvironmentMetadata so that
|
||||
// adding a field here that the proto doesn't define is a compile error. The
|
||||
// generated toJSON() serializer silently drops unknown keys — a hand-written
|
||||
// parallel type previously let #11318, #13924, #19448, and coworker_type all
|
||||
// ship fields that never reached BQ.
|
||||
// Adding a field? Update the monorepo proto first (go/cc-logging):
|
||||
// event_schemas/.../claude_code/v1/claude_code_internal_event.proto
|
||||
// then run `bun run generate:proto` here.
|
||||
const env: EnvironmentMetadata = {
|
||||
platform: envContext.platform,
|
||||
platform_raw: envContext.platformRaw,
|
||||
arch: envContext.arch,
|
||||
node_version: envContext.nodeVersion,
|
||||
terminal: envContext.terminal || 'unknown',
|
||||
package_managers: envContext.packageManagers,
|
||||
runtimes: envContext.runtimes,
|
||||
is_running_with_bun: envContext.isRunningWithBun,
|
||||
is_ci: envContext.isCi,
|
||||
is_claubbit: envContext.isClaubbit,
|
||||
is_claude_code_remote: envContext.isClaudeCodeRemote,
|
||||
is_local_agent_mode: envContext.isLocalAgentMode,
|
||||
is_conductor: envContext.isConductor,
|
||||
is_github_action: envContext.isGithubAction,
|
||||
is_claude_code_action: envContext.isClaudeCodeAction,
|
||||
is_claude_ai_auth: envContext.isClaudeAiAuth,
|
||||
version: envContext.version,
|
||||
build_time: envContext.buildTime,
|
||||
deployment_environment: envContext.deploymentEnvironment,
|
||||
}
|
||||
|
||||
// Add optional env fields
|
||||
if (envContext.remoteEnvironmentType) {
|
||||
env.remote_environment_type = envContext.remoteEnvironmentType
|
||||
}
|
||||
if (feature('COWORKER_TYPE_TELEMETRY') && envContext.coworkerType) {
|
||||
env.coworker_type = envContext.coworkerType
|
||||
}
|
||||
if (envContext.claudeCodeContainerId) {
|
||||
env.claude_code_container_id = envContext.claudeCodeContainerId
|
||||
}
|
||||
if (envContext.claudeCodeRemoteSessionId) {
|
||||
env.claude_code_remote_session_id = envContext.claudeCodeRemoteSessionId
|
||||
}
|
||||
if (envContext.tags) {
|
||||
env.tags = envContext.tags
|
||||
.split(',')
|
||||
.map(t => t.trim())
|
||||
.filter(Boolean)
|
||||
}
|
||||
if (envContext.githubEventName) {
|
||||
env.github_event_name = envContext.githubEventName
|
||||
}
|
||||
if (envContext.githubActionsRunnerEnvironment) {
|
||||
env.github_actions_runner_environment =
|
||||
envContext.githubActionsRunnerEnvironment
|
||||
}
|
||||
if (envContext.githubActionsRunnerOs) {
|
||||
env.github_actions_runner_os = envContext.githubActionsRunnerOs
|
||||
}
|
||||
if (envContext.githubActionRef) {
|
||||
env.github_action_ref = envContext.githubActionRef
|
||||
}
|
||||
if (envContext.wslVersion) {
|
||||
env.wsl_version = envContext.wslVersion
|
||||
}
|
||||
if (envContext.linuxDistroId) {
|
||||
env.linux_distro_id = envContext.linuxDistroId
|
||||
}
|
||||
if (envContext.linuxDistroVersion) {
|
||||
env.linux_distro_version = envContext.linuxDistroVersion
|
||||
}
|
||||
if (envContext.linuxKernel) {
|
||||
env.linux_kernel = envContext.linuxKernel
|
||||
}
|
||||
if (envContext.vcs) {
|
||||
env.vcs = envContext.vcs
|
||||
}
|
||||
if (envContext.versionBase) {
|
||||
env.version_base = envContext.versionBase
|
||||
}
|
||||
|
||||
// Convert core fields to snake_case
|
||||
const core: FirstPartyEventLoggingCoreMetadata = {
|
||||
session_id: coreFields.sessionId,
|
||||
model: coreFields.model,
|
||||
user_type: coreFields.userType,
|
||||
is_interactive: coreFields.isInteractive === 'true',
|
||||
client_type: coreFields.clientType,
|
||||
}
|
||||
|
||||
// Add other core fields
|
||||
if (coreFields.betas) {
|
||||
core.betas = coreFields.betas
|
||||
}
|
||||
if (coreFields.entrypoint) {
|
||||
core.entrypoint = coreFields.entrypoint
|
||||
}
|
||||
if (coreFields.agentSdkVersion) {
|
||||
core.agent_sdk_version = coreFields.agentSdkVersion
|
||||
}
|
||||
if (coreFields.sweBenchRunId) {
|
||||
core.swe_bench_run_id = coreFields.sweBenchRunId
|
||||
}
|
||||
if (coreFields.sweBenchInstanceId) {
|
||||
core.swe_bench_instance_id = coreFields.sweBenchInstanceId
|
||||
}
|
||||
if (coreFields.sweBenchTaskId) {
|
||||
core.swe_bench_task_id = coreFields.sweBenchTaskId
|
||||
}
|
||||
// Swarm/team agent identification
|
||||
if (coreFields.agentId) {
|
||||
core.agent_id = coreFields.agentId
|
||||
}
|
||||
if (coreFields.parentSessionId) {
|
||||
core.parent_session_id = coreFields.parentSessionId
|
||||
}
|
||||
if (coreFields.agentType) {
|
||||
core.agent_type = coreFields.agentType
|
||||
}
|
||||
if (coreFields.teamName) {
|
||||
core.team_name = coreFields.teamName
|
||||
}
|
||||
|
||||
// Map userMetadata to output fields.
|
||||
// Based on src/utils/user.ts getUser(), but with fields present in other
|
||||
// parts of ClaudeCodeInternalEvent deduplicated.
|
||||
// Convert camelCase GitHubActionsMetadata to snake_case for 1P API
|
||||
// Note: github_actions_metadata is placed inside env (EnvironmentMetadata)
|
||||
// rather than at the top level of ClaudeCodeInternalEvent
|
||||
if (userMetadata.githubActionsMetadata) {
|
||||
const ghMeta = userMetadata.githubActionsMetadata
|
||||
env.github_actions_metadata = {
|
||||
actor_id: ghMeta.actorId,
|
||||
repository_id: ghMeta.repositoryId,
|
||||
repository_owner_id: ghMeta.repositoryOwnerId,
|
||||
}
|
||||
}
|
||||
|
||||
let auth: PublicApiAuth | undefined
|
||||
if (userMetadata.accountUuid || userMetadata.organizationUuid) {
|
||||
auth = {
|
||||
account_uuid: userMetadata.accountUuid,
|
||||
organization_uuid: userMetadata.organizationUuid,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
env,
|
||||
...(processMetrics && {
|
||||
process: Buffer.from(jsonStringify(processMetrics)).toString('base64'),
|
||||
}),
|
||||
...(auth && { auth }),
|
||||
core,
|
||||
additional: {
|
||||
...(rh && { rh }),
|
||||
...(kairosActive && { is_assistant_mode: true }),
|
||||
...(skillMode && { skill_mode: skillMode }),
|
||||
...(observerMode && { observer_mode: observerMode }),
|
||||
...additionalMetadata,
|
||||
},
|
||||
}
|
||||
}
|
||||
114
src/services/analytics/sink.ts
Normal file
114
src/services/analytics/sink.ts
Normal file
@@ -0,0 +1,114 @@
|
||||
/**
|
||||
* Analytics sink implementation
|
||||
*
|
||||
* This module contains the actual analytics routing logic and should be
|
||||
* initialized during app startup. It routes events to Datadog and 1P event
|
||||
* logging.
|
||||
*
|
||||
* Usage: Call initializeAnalyticsSink() during app startup to attach the sink.
|
||||
*/
|
||||
|
||||
import { trackDatadogEvent } from './datadog.js'
|
||||
import { logEventTo1P, shouldSampleEvent } from './firstPartyEventLogger.js'
|
||||
import { checkStatsigFeatureGate_CACHED_MAY_BE_STALE } from './growthbook.js'
|
||||
import { attachAnalyticsSink, stripProtoFields } from './index.js'
|
||||
import { isSinkKilled } from './sinkKillswitch.js'
|
||||
|
||||
// Local type matching the logEvent metadata signature
|
||||
type LogEventMetadata = { [key: string]: boolean | number | undefined }
|
||||
|
||||
const DATADOG_GATE_NAME = 'tengu_log_datadog_events'
|
||||
|
||||
// Module-level gate state - starts undefined, initialized during startup
|
||||
let isDatadogGateEnabled: boolean | undefined = undefined
|
||||
|
||||
/**
|
||||
* Check if Datadog tracking is enabled.
|
||||
* Falls back to cached value from previous session if not yet initialized.
|
||||
*/
|
||||
function shouldTrackDatadog(): boolean {
|
||||
if (isSinkKilled('datadog')) {
|
||||
return false
|
||||
}
|
||||
if (isDatadogGateEnabled !== undefined) {
|
||||
return isDatadogGateEnabled
|
||||
}
|
||||
|
||||
// Fallback to cached value from previous session
|
||||
try {
|
||||
return checkStatsigFeatureGate_CACHED_MAY_BE_STALE(DATADOG_GATE_NAME)
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Log an event (synchronous implementation)
|
||||
*/
|
||||
function logEventImpl(eventName: string, metadata: LogEventMetadata): void {
|
||||
// Check if this event should be sampled
|
||||
const sampleResult = shouldSampleEvent(eventName)
|
||||
|
||||
// If sample result is 0, the event was not selected for logging
|
||||
if (sampleResult === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
// If sample result is a positive number, add it to metadata
|
||||
const metadataWithSampleRate =
|
||||
sampleResult !== null
|
||||
? { ...metadata, sample_rate: sampleResult }
|
||||
: metadata
|
||||
|
||||
if (shouldTrackDatadog()) {
|
||||
// Datadog is a general-access backend — strip _PROTO_* keys
|
||||
// (unredacted PII-tagged values meant only for the 1P privileged column).
|
||||
void trackDatadogEvent(eventName, stripProtoFields(metadataWithSampleRate))
|
||||
}
|
||||
|
||||
// 1P receives the full payload including _PROTO_* — the exporter
|
||||
// destructures and routes those keys to proto fields itself.
|
||||
logEventTo1P(eventName, metadataWithSampleRate)
|
||||
}
|
||||
|
||||
/**
|
||||
* Log an event (asynchronous implementation)
|
||||
*
|
||||
* With Segment removed the two remaining sinks are fire-and-forget, so this
|
||||
* just wraps the sync impl — kept to preserve the sink interface contract.
|
||||
*/
|
||||
function logEventAsyncImpl(
|
||||
eventName: string,
|
||||
metadata: LogEventMetadata,
|
||||
): Promise<void> {
|
||||
logEventImpl(eventName, metadata)
|
||||
return Promise.resolve()
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize analytics gates during startup.
|
||||
*
|
||||
* Updates gate values from server. Early events use cached values from previous
|
||||
* session to avoid data loss during initialization.
|
||||
*
|
||||
* Called from main.tsx during setupBackend().
|
||||
*/
|
||||
export function initializeAnalyticsGates(): void {
|
||||
isDatadogGateEnabled =
|
||||
checkStatsigFeatureGate_CACHED_MAY_BE_STALE(DATADOG_GATE_NAME)
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the analytics sink.
|
||||
*
|
||||
* Call this during app startup to attach the analytics backend.
|
||||
* Any events logged before this is called will be queued and drained.
|
||||
*
|
||||
* Idempotent: safe to call multiple times (subsequent calls are no-ops).
|
||||
*/
|
||||
export function initializeAnalyticsSink(): void {
|
||||
attachAnalyticsSink({
|
||||
logEvent: logEventImpl,
|
||||
logEventAsync: logEventAsyncImpl,
|
||||
})
|
||||
}
|
||||
25
src/services/analytics/sinkKillswitch.ts
Normal file
25
src/services/analytics/sinkKillswitch.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { getDynamicConfig_CACHED_MAY_BE_STALE } from './growthbook.js'
|
||||
|
||||
// Mangled name: per-sink analytics killswitch
|
||||
const SINK_KILLSWITCH_CONFIG_NAME = 'tengu_frond_boric'
|
||||
|
||||
export type SinkName = 'datadog' | 'firstParty'
|
||||
|
||||
/**
|
||||
* GrowthBook JSON config that disables individual analytics sinks.
|
||||
* Shape: { datadog?: boolean, firstParty?: boolean }
|
||||
* A value of true for a key stops all dispatch to that sink.
|
||||
* Default {} (nothing killed). Fail-open: missing/malformed config = sink stays on.
|
||||
*
|
||||
* NOTE: Must NOT be called from inside is1PEventLoggingEnabled() -
|
||||
* growthbook.ts:isGrowthBookEnabled() calls that, so a lookup here would recurse.
|
||||
* Call at per-event dispatch sites instead.
|
||||
*/
|
||||
export function isSinkKilled(sink: SinkName): boolean {
|
||||
const config = getDynamicConfig_CACHED_MAY_BE_STALE<
|
||||
Partial<Record<SinkName, boolean>>
|
||||
>(SINK_KILLSWITCH_CONFIG_NAME, {})
|
||||
// getFeatureValue_CACHED_MAY_BE_STALE guards on `!== undefined`, so a
|
||||
// cached JSON null leaks through instead of falling back to {}.
|
||||
return config?.[sink] === true
|
||||
}
|
||||
119
src/services/api/adminRequests.ts
Normal file
119
src/services/api/adminRequests.ts
Normal file
@@ -0,0 +1,119 @@
|
||||
import axios from 'axios'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import { getOAuthHeaders, prepareApiRequest } from '../../utils/teleport/api.js'
|
||||
|
||||
export type AdminRequestType = 'limit_increase' | 'seat_upgrade'
|
||||
|
||||
export type AdminRequestStatus = 'pending' | 'approved' | 'dismissed'
|
||||
|
||||
export type AdminRequestSeatUpgradeDetails = {
|
||||
message?: string | null
|
||||
current_seat_tier?: string | null
|
||||
}
|
||||
|
||||
export type AdminRequestCreateParams =
|
||||
| {
|
||||
request_type: 'limit_increase'
|
||||
details: null
|
||||
}
|
||||
| {
|
||||
request_type: 'seat_upgrade'
|
||||
details: AdminRequestSeatUpgradeDetails
|
||||
}
|
||||
|
||||
export type AdminRequest = {
|
||||
uuid: string
|
||||
status: AdminRequestStatus
|
||||
requester_uuid?: string | null
|
||||
created_at: string
|
||||
} & (
|
||||
| {
|
||||
request_type: 'limit_increase'
|
||||
details: null
|
||||
}
|
||||
| {
|
||||
request_type: 'seat_upgrade'
|
||||
details: AdminRequestSeatUpgradeDetails
|
||||
}
|
||||
)
|
||||
|
||||
/**
|
||||
* Create an admin request (limit increase or seat upgrade).
|
||||
*
|
||||
* For Team/Enterprise users who don't have billing/admin permissions,
|
||||
* this creates a request that their admin can act on.
|
||||
*
|
||||
* If a pending request of the same type already exists for this user,
|
||||
* returns the existing request instead of creating a new one.
|
||||
*/
|
||||
export async function createAdminRequest(
|
||||
params: AdminRequestCreateParams,
|
||||
): Promise<AdminRequest> {
|
||||
const { accessToken, orgUUID } = await prepareApiRequest()
|
||||
|
||||
const headers = {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
}
|
||||
|
||||
const url = `${getOauthConfig().BASE_API_URL}/api/oauth/organizations/${orgUUID}/admin_requests`
|
||||
|
||||
const response = await axios.post<AdminRequest>(url, params, { headers })
|
||||
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get pending admin request of a specific type for the current user.
|
||||
*
|
||||
* Returns the pending request if one exists, otherwise null.
|
||||
*/
|
||||
export async function getMyAdminRequests(
|
||||
requestType: AdminRequestType,
|
||||
statuses: AdminRequestStatus[],
|
||||
): Promise<AdminRequest[] | null> {
|
||||
const { accessToken, orgUUID } = await prepareApiRequest()
|
||||
|
||||
const headers = {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
}
|
||||
|
||||
let url = `${getOauthConfig().BASE_API_URL}/api/oauth/organizations/${orgUUID}/admin_requests/me?request_type=${requestType}`
|
||||
for (const status of statuses) {
|
||||
url += `&statuses=${status}`
|
||||
}
|
||||
|
||||
const response = await axios.get<AdminRequest[] | null>(url, {
|
||||
headers,
|
||||
})
|
||||
|
||||
return response.data
|
||||
}
|
||||
|
||||
type AdminRequestEligibilityResponse = {
|
||||
request_type: AdminRequestType
|
||||
is_allowed: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a specific admin request type is allowed for this org.
|
||||
*/
|
||||
export async function checkAdminRequestEligibility(
|
||||
requestType: AdminRequestType,
|
||||
): Promise<AdminRequestEligibilityResponse | null> {
|
||||
const { accessToken, orgUUID } = await prepareApiRequest()
|
||||
|
||||
const headers = {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
}
|
||||
|
||||
const url = `${getOauthConfig().BASE_API_URL}/api/oauth/organizations/${orgUUID}/admin_requests/eligibility?request_type=${requestType}`
|
||||
|
||||
const response = await axios.get<AdminRequestEligibilityResponse>(url, {
|
||||
headers,
|
||||
})
|
||||
|
||||
return response.data
|
||||
}
|
||||
141
src/services/api/bootstrap.ts
Normal file
141
src/services/api/bootstrap.ts
Normal file
@@ -0,0 +1,141 @@
|
||||
import axios from 'axios'
|
||||
import isEqual from 'lodash-es/isEqual.js'
|
||||
import {
|
||||
getAnthropicApiKey,
|
||||
getClaudeAIOAuthTokens,
|
||||
hasProfileScope,
|
||||
} from 'src/utils/auth.js'
|
||||
import { z } from 'zod'
|
||||
import { getOauthConfig, OAUTH_BETA_HEADER } from '../../constants/oauth.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../../utils/config.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { withOAuth401Retry } from '../../utils/http.js'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getAPIProvider } from '../../utils/model/providers.js'
|
||||
import { isEssentialTrafficOnly } from '../../utils/privacyLevel.js'
|
||||
import { getClaudeCodeUserAgent } from '../../utils/userAgent.js'
|
||||
|
||||
const bootstrapResponseSchema = lazySchema(() =>
|
||||
z.object({
|
||||
client_data: z.record(z.unknown()).nullish(),
|
||||
additional_model_options: z
|
||||
.array(
|
||||
z
|
||||
.object({
|
||||
model: z.string(),
|
||||
name: z.string(),
|
||||
description: z.string(),
|
||||
})
|
||||
.transform(({ model, name, description }) => ({
|
||||
value: model,
|
||||
label: name,
|
||||
description,
|
||||
})),
|
||||
)
|
||||
.nullish(),
|
||||
}),
|
||||
)
|
||||
|
||||
type BootstrapResponse = z.infer<ReturnType<typeof bootstrapResponseSchema>>
|
||||
|
||||
async function fetchBootstrapAPI(): Promise<BootstrapResponse | null> {
|
||||
if (isEssentialTrafficOnly()) {
|
||||
logForDebugging('[Bootstrap] Skipped: Nonessential traffic disabled')
|
||||
return null
|
||||
}
|
||||
|
||||
if (getAPIProvider() !== 'firstParty') {
|
||||
logForDebugging('[Bootstrap] Skipped: 3P provider')
|
||||
return null
|
||||
}
|
||||
|
||||
// OAuth preferred (requires user:profile scope — service-key OAuth tokens
|
||||
// lack it and would 403). Fall back to API key auth for console users.
|
||||
const apiKey = getAnthropicApiKey()
|
||||
const hasUsableOAuth =
|
||||
getClaudeAIOAuthTokens()?.accessToken && hasProfileScope()
|
||||
if (!hasUsableOAuth && !apiKey) {
|
||||
logForDebugging('[Bootstrap] Skipped: no usable OAuth or API key')
|
||||
return null
|
||||
}
|
||||
|
||||
const endpoint = `${getOauthConfig().BASE_API_URL}/api/claude_cli/bootstrap`
|
||||
|
||||
// withOAuth401Retry handles the refresh-and-retry. API key users fail
|
||||
// through on 401 (no refresh mechanism — no OAuth token to pass).
|
||||
try {
|
||||
return await withOAuth401Retry(async () => {
|
||||
// Re-read OAuth each call so the retry picks up the refreshed token.
|
||||
const token = getClaudeAIOAuthTokens()?.accessToken
|
||||
let authHeaders: Record<string, string>
|
||||
if (token && hasProfileScope()) {
|
||||
authHeaders = {
|
||||
Authorization: `Bearer ${token}`,
|
||||
'anthropic-beta': OAUTH_BETA_HEADER,
|
||||
}
|
||||
} else if (apiKey) {
|
||||
authHeaders = { 'x-api-key': apiKey }
|
||||
} else {
|
||||
logForDebugging('[Bootstrap] No auth available on retry, aborting')
|
||||
return null
|
||||
}
|
||||
|
||||
logForDebugging('[Bootstrap] Fetching')
|
||||
const response = await axios.get<unknown>(endpoint, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
...authHeaders,
|
||||
},
|
||||
timeout: 5000,
|
||||
})
|
||||
const parsed = bootstrapResponseSchema().safeParse(response.data)
|
||||
if (!parsed.success) {
|
||||
logForDebugging(
|
||||
`[Bootstrap] Response failed validation: ${parsed.error.message}`,
|
||||
)
|
||||
return null
|
||||
}
|
||||
logForDebugging('[Bootstrap] Fetch ok')
|
||||
return parsed.data
|
||||
})
|
||||
} catch (error) {
|
||||
logForDebugging(
|
||||
`[Bootstrap] Fetch failed: ${axios.isAxiosError(error) ? (error.response?.status ?? error.code) : 'unknown'}`,
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch bootstrap data from the API and persist to disk cache.
|
||||
*/
|
||||
export async function fetchBootstrapData(): Promise<void> {
|
||||
try {
|
||||
const response = await fetchBootstrapAPI()
|
||||
if (!response) return
|
||||
|
||||
const clientData = response.client_data ?? null
|
||||
const additionalModelOptions = response.additional_model_options ?? []
|
||||
|
||||
// Only persist if data actually changed — avoids a config write on every startup.
|
||||
const config = getGlobalConfig()
|
||||
if (
|
||||
isEqual(config.clientDataCache, clientData) &&
|
||||
isEqual(config.additionalModelOptionsCache, additionalModelOptions)
|
||||
) {
|
||||
logForDebugging('[Bootstrap] Cache unchanged, skipping write')
|
||||
return
|
||||
}
|
||||
|
||||
logForDebugging('[Bootstrap] Cache updated, persisting to disk')
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
clientDataCache: clientData,
|
||||
additionalModelOptionsCache: additionalModelOptions,
|
||||
}))
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
3419
src/services/api/claude.ts
Normal file
3419
src/services/api/claude.ts
Normal file
File diff suppressed because it is too large
Load Diff
389
src/services/api/client.ts
Normal file
389
src/services/api/client.ts
Normal file
@@ -0,0 +1,389 @@
|
||||
import Anthropic, { type ClientOptions } from '@anthropic-ai/sdk'
|
||||
import { randomUUID } from 'crypto'
|
||||
import type { GoogleAuth } from 'google-auth-library'
|
||||
import {
|
||||
checkAndRefreshOAuthTokenIfNeeded,
|
||||
getAnthropicApiKey,
|
||||
getApiKeyFromApiKeyHelper,
|
||||
getClaudeAIOAuthTokens,
|
||||
isClaudeAISubscriber,
|
||||
refreshAndGetAwsCredentials,
|
||||
refreshGcpCredentialsIfNeeded,
|
||||
} from 'src/utils/auth.js'
|
||||
import { getUserAgent } from 'src/utils/http.js'
|
||||
import { getSmallFastModel } from 'src/utils/model/model.js'
|
||||
import {
|
||||
getAPIProvider,
|
||||
isFirstPartyAnthropicBaseUrl,
|
||||
} from 'src/utils/model/providers.js'
|
||||
import { getProxyFetchOptions } from 'src/utils/proxy.js'
|
||||
import {
|
||||
getIsNonInteractiveSession,
|
||||
getSessionId,
|
||||
} from '../../bootstrap/state.js'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import { isDebugToStdErr, logForDebugging } from '../../utils/debug.js'
|
||||
import {
|
||||
getAWSRegion,
|
||||
getVertexRegionForModel,
|
||||
isEnvTruthy,
|
||||
} from '../../utils/envUtils.js'
|
||||
|
||||
/**
|
||||
* Environment variables for different client types:
|
||||
*
|
||||
* Direct API:
|
||||
* - ANTHROPIC_API_KEY: Required for direct API access
|
||||
*
|
||||
* AWS Bedrock:
|
||||
* - AWS credentials configured via aws-sdk defaults
|
||||
* - AWS_REGION or AWS_DEFAULT_REGION: Sets the AWS region for all models (default: us-east-1)
|
||||
* - ANTHROPIC_SMALL_FAST_MODEL_AWS_REGION: Optional. Override AWS region specifically for the small fast model (Haiku)
|
||||
*
|
||||
* Foundry (Azure):
|
||||
* - ANTHROPIC_FOUNDRY_RESOURCE: Your Azure resource name (e.g., 'my-resource')
|
||||
* For the full endpoint: https://{resource}.services.ai.azure.com/anthropic/v1/messages
|
||||
* - ANTHROPIC_FOUNDRY_BASE_URL: Optional. Alternative to resource - provide full base URL directly
|
||||
* (e.g., 'https://my-resource.services.ai.azure.com')
|
||||
*
|
||||
* Authentication (one of the following):
|
||||
* - ANTHROPIC_FOUNDRY_API_KEY: Your Microsoft Foundry API key (if using API key auth)
|
||||
* - Azure AD authentication: If no API key is provided, uses DefaultAzureCredential
|
||||
* which supports multiple auth methods (environment variables, managed identity,
|
||||
* Azure CLI, etc.). See: https://docs.microsoft.com/en-us/javascript/api/@azure/identity
|
||||
*
|
||||
* Vertex AI:
|
||||
* - Model-specific region variables (highest priority):
|
||||
* - VERTEX_REGION_CLAUDE_3_5_HAIKU: Region for Claude 3.5 Haiku model
|
||||
* - VERTEX_REGION_CLAUDE_HAIKU_4_5: Region for Claude Haiku 4.5 model
|
||||
* - VERTEX_REGION_CLAUDE_3_5_SONNET: Region for Claude 3.5 Sonnet model
|
||||
* - VERTEX_REGION_CLAUDE_3_7_SONNET: Region for Claude 3.7 Sonnet model
|
||||
* - CLOUD_ML_REGION: Optional. The default GCP region to use for all models
|
||||
* If specific model region not specified above
|
||||
* - ANTHROPIC_VERTEX_PROJECT_ID: Required. Your GCP project ID
|
||||
* - Standard GCP credentials configured via google-auth-library
|
||||
*
|
||||
* Priority for determining region:
|
||||
* 1. Hardcoded model-specific environment variables
|
||||
* 2. Global CLOUD_ML_REGION variable
|
||||
* 3. Default region from config
|
||||
* 4. Fallback region (us-east5)
|
||||
*/
|
||||
|
||||
function createStderrLogger(): ClientOptions['logger'] {
|
||||
return {
|
||||
error: (msg, ...args) =>
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output -- SDK logger must use console
|
||||
console.error('[Anthropic SDK ERROR]', msg, ...args),
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output -- SDK logger must use console
|
||||
warn: (msg, ...args) => console.error('[Anthropic SDK WARN]', msg, ...args),
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output -- SDK logger must use console
|
||||
info: (msg, ...args) => console.error('[Anthropic SDK INFO]', msg, ...args),
|
||||
debug: (msg, ...args) =>
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output -- SDK logger must use console
|
||||
console.error('[Anthropic SDK DEBUG]', msg, ...args),
|
||||
}
|
||||
}
|
||||
|
||||
export async function getAnthropicClient({
|
||||
apiKey,
|
||||
maxRetries,
|
||||
model,
|
||||
fetchOverride,
|
||||
source,
|
||||
}: {
|
||||
apiKey?: string
|
||||
maxRetries: number
|
||||
model?: string
|
||||
fetchOverride?: ClientOptions['fetch']
|
||||
source?: string
|
||||
}): Promise<Anthropic> {
|
||||
const containerId = process.env.CLAUDE_CODE_CONTAINER_ID
|
||||
const remoteSessionId = process.env.CLAUDE_CODE_REMOTE_SESSION_ID
|
||||
const clientApp = process.env.CLAUDE_AGENT_SDK_CLIENT_APP
|
||||
const customHeaders = getCustomHeaders()
|
||||
const defaultHeaders: { [key: string]: string } = {
|
||||
'x-app': 'cli',
|
||||
'User-Agent': getUserAgent(),
|
||||
'X-Claude-Code-Session-Id': getSessionId(),
|
||||
...customHeaders,
|
||||
...(containerId ? { 'x-claude-remote-container-id': containerId } : {}),
|
||||
...(remoteSessionId
|
||||
? { 'x-claude-remote-session-id': remoteSessionId }
|
||||
: {}),
|
||||
// SDK consumers can identify their app/library for backend analytics
|
||||
...(clientApp ? { 'x-client-app': clientApp } : {}),
|
||||
}
|
||||
|
||||
// Log API client configuration for HFI debugging
|
||||
logForDebugging(
|
||||
`[API:request] Creating client, ANTHROPIC_CUSTOM_HEADERS present: ${!!process.env.ANTHROPIC_CUSTOM_HEADERS}, has Authorization header: ${!!customHeaders['Authorization']}`,
|
||||
)
|
||||
|
||||
// Add additional protection header if enabled via env var
|
||||
const additionalProtectionEnabled = isEnvTruthy(
|
||||
process.env.CLAUDE_CODE_ADDITIONAL_PROTECTION,
|
||||
)
|
||||
if (additionalProtectionEnabled) {
|
||||
defaultHeaders['x-anthropic-additional-protection'] = 'true'
|
||||
}
|
||||
|
||||
logForDebugging('[API:auth] OAuth token check starting')
|
||||
await checkAndRefreshOAuthTokenIfNeeded()
|
||||
logForDebugging('[API:auth] OAuth token check complete')
|
||||
|
||||
if (!isClaudeAISubscriber()) {
|
||||
await configureApiKeyHeaders(defaultHeaders, getIsNonInteractiveSession())
|
||||
}
|
||||
|
||||
const resolvedFetch = buildFetch(fetchOverride, source)
|
||||
|
||||
const ARGS = {
|
||||
defaultHeaders,
|
||||
maxRetries,
|
||||
timeout: parseInt(process.env.API_TIMEOUT_MS || String(600 * 1000), 10),
|
||||
dangerouslyAllowBrowser: true,
|
||||
fetchOptions: getProxyFetchOptions({
|
||||
forAnthropicAPI: true,
|
||||
}) as ClientOptions['fetchOptions'],
|
||||
...(resolvedFetch && {
|
||||
fetch: resolvedFetch,
|
||||
}),
|
||||
}
|
||||
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_BEDROCK)) {
|
||||
const { AnthropicBedrock } = await import('@anthropic-ai/bedrock-sdk')
|
||||
// Use region override for small fast model if specified
|
||||
const awsRegion =
|
||||
model === getSmallFastModel() &&
|
||||
process.env.ANTHROPIC_SMALL_FAST_MODEL_AWS_REGION
|
||||
? process.env.ANTHROPIC_SMALL_FAST_MODEL_AWS_REGION
|
||||
: getAWSRegion()
|
||||
|
||||
const bedrockArgs: ConstructorParameters<typeof AnthropicBedrock>[0] = {
|
||||
...ARGS,
|
||||
awsRegion,
|
||||
...(isEnvTruthy(process.env.CLAUDE_CODE_SKIP_BEDROCK_AUTH) && {
|
||||
skipAuth: true,
|
||||
}),
|
||||
...(isDebugToStdErr() && { logger: createStderrLogger() }),
|
||||
}
|
||||
|
||||
// Add API key authentication if available
|
||||
if (process.env.AWS_BEARER_TOKEN_BEDROCK) {
|
||||
bedrockArgs.skipAuth = true
|
||||
// Add the Bearer token for Bedrock API key authentication
|
||||
bedrockArgs.defaultHeaders = {
|
||||
...bedrockArgs.defaultHeaders,
|
||||
Authorization: `Bearer ${process.env.AWS_BEARER_TOKEN_BEDROCK}`,
|
||||
}
|
||||
} else if (!isEnvTruthy(process.env.CLAUDE_CODE_SKIP_BEDROCK_AUTH)) {
|
||||
// Refresh auth and get credentials with cache clearing
|
||||
const cachedCredentials = await refreshAndGetAwsCredentials()
|
||||
if (cachedCredentials) {
|
||||
bedrockArgs.awsAccessKey = cachedCredentials.accessKeyId
|
||||
bedrockArgs.awsSecretKey = cachedCredentials.secretAccessKey
|
||||
bedrockArgs.awsSessionToken = cachedCredentials.sessionToken
|
||||
}
|
||||
}
|
||||
// we have always been lying about the return type - this doesn't support batching or models
|
||||
return new AnthropicBedrock(bedrockArgs) as unknown as Anthropic
|
||||
}
|
||||
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_FOUNDRY)) {
|
||||
const { AnthropicFoundry } = await import('@anthropic-ai/foundry-sdk')
|
||||
// Determine Azure AD token provider based on configuration
|
||||
// SDK reads ANTHROPIC_FOUNDRY_API_KEY by default
|
||||
let azureADTokenProvider: (() => Promise<string>) | undefined
|
||||
if (!process.env.ANTHROPIC_FOUNDRY_API_KEY) {
|
||||
if (isEnvTruthy(process.env.CLAUDE_CODE_SKIP_FOUNDRY_AUTH)) {
|
||||
// Mock token provider for testing/proxy scenarios (similar to Vertex mock GoogleAuth)
|
||||
azureADTokenProvider = () => Promise.resolve('')
|
||||
} else {
|
||||
// Use real Azure AD authentication with DefaultAzureCredential
|
||||
const {
|
||||
DefaultAzureCredential: AzureCredential,
|
||||
getBearerTokenProvider,
|
||||
} = await import('@azure/identity')
|
||||
azureADTokenProvider = getBearerTokenProvider(
|
||||
new AzureCredential(),
|
||||
'https://cognitiveservices.azure.com/.default',
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const foundryArgs: ConstructorParameters<typeof AnthropicFoundry>[0] = {
|
||||
...ARGS,
|
||||
...(azureADTokenProvider && { azureADTokenProvider }),
|
||||
...(isDebugToStdErr() && { logger: createStderrLogger() }),
|
||||
}
|
||||
// we have always been lying about the return type - this doesn't support batching or models
|
||||
return new AnthropicFoundry(foundryArgs) as unknown as Anthropic
|
||||
}
|
||||
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_VERTEX)) {
|
||||
// Refresh GCP credentials if gcpAuthRefresh is configured and credentials are expired
|
||||
// This is similar to how we handle AWS credential refresh for Bedrock
|
||||
if (!isEnvTruthy(process.env.CLAUDE_CODE_SKIP_VERTEX_AUTH)) {
|
||||
await refreshGcpCredentialsIfNeeded()
|
||||
}
|
||||
|
||||
const [{ AnthropicVertex }, { GoogleAuth }] = await Promise.all([
|
||||
import('@anthropic-ai/vertex-sdk'),
|
||||
import('google-auth-library'),
|
||||
])
|
||||
// TODO: Cache either GoogleAuth instance or AuthClient to improve performance
|
||||
// Currently we create a new GoogleAuth instance for every getAnthropicClient() call
|
||||
// This could cause repeated authentication flows and metadata server checks
|
||||
// However, caching needs careful handling of:
|
||||
// - Credential refresh/expiration
|
||||
// - Environment variable changes (GOOGLE_APPLICATION_CREDENTIALS, project vars)
|
||||
// - Cross-request auth state management
|
||||
// See: https://github.com/googleapis/google-auth-library-nodejs/issues/390 for caching challenges
|
||||
|
||||
// Prevent metadata server timeout by providing projectId as fallback
|
||||
// google-auth-library checks project ID in this order:
|
||||
// 1. Environment variables (GCLOUD_PROJECT, GOOGLE_CLOUD_PROJECT, etc.)
|
||||
// 2. Credential files (service account JSON, ADC file)
|
||||
// 3. gcloud config
|
||||
// 4. GCE metadata server (causes 12s timeout outside GCP)
|
||||
//
|
||||
// We only set projectId if user hasn't configured other discovery methods
|
||||
// to avoid interfering with their existing auth setup
|
||||
|
||||
// Check project environment variables in same order as google-auth-library
|
||||
// See: https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts
|
||||
const hasProjectEnvVar =
|
||||
process.env['GCLOUD_PROJECT'] ||
|
||||
process.env['GOOGLE_CLOUD_PROJECT'] ||
|
||||
process.env['gcloud_project'] ||
|
||||
process.env['google_cloud_project']
|
||||
|
||||
// Check for credential file paths (service account or ADC)
|
||||
// Note: We're checking both standard and lowercase variants to be safe,
|
||||
// though we should verify what google-auth-library actually checks
|
||||
const hasKeyFile =
|
||||
process.env['GOOGLE_APPLICATION_CREDENTIALS'] ||
|
||||
process.env['google_application_credentials']
|
||||
|
||||
const googleAuth = isEnvTruthy(process.env.CLAUDE_CODE_SKIP_VERTEX_AUTH)
|
||||
? ({
|
||||
// Mock GoogleAuth for testing/proxy scenarios
|
||||
getClient: () => ({
|
||||
getRequestHeaders: () => ({}),
|
||||
}),
|
||||
} as unknown as GoogleAuth)
|
||||
: new GoogleAuth({
|
||||
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
|
||||
// Only use ANTHROPIC_VERTEX_PROJECT_ID as last resort fallback
|
||||
// This prevents the 12-second metadata server timeout when:
|
||||
// - No project env vars are set AND
|
||||
// - No credential keyfile is specified AND
|
||||
// - ADC file exists but lacks project_id field
|
||||
//
|
||||
// Risk: If auth project != API target project, this could cause billing/audit issues
|
||||
// Mitigation: Users can set GOOGLE_CLOUD_PROJECT to override
|
||||
...(hasProjectEnvVar || hasKeyFile
|
||||
? {}
|
||||
: {
|
||||
projectId: process.env.ANTHROPIC_VERTEX_PROJECT_ID,
|
||||
}),
|
||||
})
|
||||
|
||||
const vertexArgs: ConstructorParameters<typeof AnthropicVertex>[0] = {
|
||||
...ARGS,
|
||||
region: getVertexRegionForModel(model),
|
||||
googleAuth,
|
||||
...(isDebugToStdErr() && { logger: createStderrLogger() }),
|
||||
}
|
||||
// we have always been lying about the return type - this doesn't support batching or models
|
||||
return new AnthropicVertex(vertexArgs) as unknown as Anthropic
|
||||
}
|
||||
|
||||
// Determine authentication method based on available tokens
|
||||
const clientConfig: ConstructorParameters<typeof Anthropic>[0] = {
|
||||
apiKey: isClaudeAISubscriber() ? null : apiKey || getAnthropicApiKey(),
|
||||
authToken: isClaudeAISubscriber()
|
||||
? getClaudeAIOAuthTokens()?.accessToken
|
||||
: undefined,
|
||||
// Set baseURL from OAuth config when using staging OAuth
|
||||
...(process.env.USER_TYPE === 'ant' &&
|
||||
isEnvTruthy(process.env.USE_STAGING_OAUTH)
|
||||
? { baseURL: getOauthConfig().BASE_API_URL }
|
||||
: {}),
|
||||
...ARGS,
|
||||
...(isDebugToStdErr() && { logger: createStderrLogger() }),
|
||||
}
|
||||
|
||||
return new Anthropic(clientConfig)
|
||||
}
|
||||
|
||||
async function configureApiKeyHeaders(
|
||||
headers: Record<string, string>,
|
||||
isNonInteractiveSession: boolean,
|
||||
): Promise<void> {
|
||||
const token =
|
||||
process.env.ANTHROPIC_AUTH_TOKEN ||
|
||||
(await getApiKeyFromApiKeyHelper(isNonInteractiveSession))
|
||||
if (token) {
|
||||
headers['Authorization'] = `Bearer ${token}`
|
||||
}
|
||||
}
|
||||
|
||||
function getCustomHeaders(): Record<string, string> {
|
||||
const customHeaders: Record<string, string> = {}
|
||||
const customHeadersEnv = process.env.ANTHROPIC_CUSTOM_HEADERS
|
||||
|
||||
if (!customHeadersEnv) return customHeaders
|
||||
|
||||
// Split by newlines to support multiple headers
|
||||
const headerStrings = customHeadersEnv.split(/\n|\r\n/)
|
||||
|
||||
for (const headerString of headerStrings) {
|
||||
if (!headerString.trim()) continue
|
||||
|
||||
// Parse header in format "Name: Value" (curl style). Split on first `:`
|
||||
// then trim — avoids regex backtracking on malformed long header lines.
|
||||
const colonIdx = headerString.indexOf(':')
|
||||
if (colonIdx === -1) continue
|
||||
const name = headerString.slice(0, colonIdx).trim()
|
||||
const value = headerString.slice(colonIdx + 1).trim()
|
||||
if (name) {
|
||||
customHeaders[name] = value
|
||||
}
|
||||
}
|
||||
|
||||
return customHeaders
|
||||
}
|
||||
|
||||
export const CLIENT_REQUEST_ID_HEADER = 'x-client-request-id'
|
||||
|
||||
function buildFetch(
|
||||
fetchOverride: ClientOptions['fetch'],
|
||||
source: string | undefined,
|
||||
): ClientOptions['fetch'] {
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
const inner = fetchOverride ?? globalThis.fetch
|
||||
// Only send to the first-party API — Bedrock/Vertex/Foundry don't log it
|
||||
// and unknown headers risk rejection by strict proxies (inc-4029 class).
|
||||
const injectClientRequestId =
|
||||
getAPIProvider() === 'firstParty' && isFirstPartyAnthropicBaseUrl()
|
||||
return (input, init) => {
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
const headers = new Headers(init?.headers)
|
||||
// Generate a client-side request ID so timeouts (which return no server
|
||||
// request ID) can still be correlated with server logs by the API team.
|
||||
// Callers that want to track the ID themselves can pre-set the header.
|
||||
if (injectClientRequestId && !headers.has(CLIENT_REQUEST_ID_HEADER)) {
|
||||
headers.set(CLIENT_REQUEST_ID_HEADER, randomUUID())
|
||||
}
|
||||
try {
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
const url = input instanceof Request ? input.url : String(input)
|
||||
const id = headers.get(CLIENT_REQUEST_ID_HEADER)
|
||||
logForDebugging(
|
||||
`[API REQUEST] ${new URL(url).pathname}${id ? ` ${CLIENT_REQUEST_ID_HEADER}=${id}` : ''} source=${source ?? 'unknown'}`,
|
||||
)
|
||||
} catch {
|
||||
// never let logging crash the fetch
|
||||
}
|
||||
return inner(input, { ...init, headers })
|
||||
}
|
||||
}
|
||||
226
src/services/api/dumpPrompts.ts
Normal file
226
src/services/api/dumpPrompts.ts
Normal file
@@ -0,0 +1,226 @@
|
||||
import type { ClientOptions } from '@anthropic-ai/sdk'
|
||||
import { createHash } from 'crypto'
|
||||
import { promises as fs } from 'fs'
|
||||
import { dirname, join } from 'path'
|
||||
import { getSessionId } from 'src/bootstrap/state.js'
|
||||
import { getClaudeConfigHomeDir } from '../../utils/envUtils.js'
|
||||
import { jsonParse, jsonStringify } from '../../utils/slowOperations.js'
|
||||
|
||||
function hashString(str: string): string {
|
||||
return createHash('sha256').update(str).digest('hex')
|
||||
}
|
||||
|
||||
// Cache last few API requests for ant users (e.g., for /issue command)
|
||||
const MAX_CACHED_REQUESTS = 5
|
||||
const cachedApiRequests: Array<{ timestamp: string; request: unknown }> = []
|
||||
|
||||
type DumpState = {
|
||||
initialized: boolean
|
||||
messageCountSeen: number
|
||||
lastInitDataHash: string
|
||||
// Cheap proxy for change detection — skips the expensive stringify+hash
|
||||
// when model/tools/system are structurally identical to the last call.
|
||||
lastInitFingerprint: string
|
||||
}
|
||||
|
||||
// Track state per session to avoid duplicating data
|
||||
const dumpState = new Map<string, DumpState>()
|
||||
|
||||
export function getLastApiRequests(): Array<{
|
||||
timestamp: string
|
||||
request: unknown
|
||||
}> {
|
||||
return [...cachedApiRequests]
|
||||
}
|
||||
|
||||
export function clearApiRequestCache(): void {
|
||||
cachedApiRequests.length = 0
|
||||
}
|
||||
|
||||
export function clearDumpState(agentIdOrSessionId: string): void {
|
||||
dumpState.delete(agentIdOrSessionId)
|
||||
}
|
||||
|
||||
export function clearAllDumpState(): void {
|
||||
dumpState.clear()
|
||||
}
|
||||
|
||||
export function addApiRequestToCache(requestData: unknown): void {
|
||||
if (process.env.USER_TYPE !== 'ant') return
|
||||
cachedApiRequests.push({
|
||||
timestamp: new Date().toISOString(),
|
||||
request: requestData,
|
||||
})
|
||||
if (cachedApiRequests.length > MAX_CACHED_REQUESTS) {
|
||||
cachedApiRequests.shift()
|
||||
}
|
||||
}
|
||||
|
||||
export function getDumpPromptsPath(agentIdOrSessionId?: string): string {
|
||||
return join(
|
||||
getClaudeConfigHomeDir(),
|
||||
'dump-prompts',
|
||||
`${agentIdOrSessionId ?? getSessionId()}.jsonl`,
|
||||
)
|
||||
}
|
||||
|
||||
function appendToFile(filePath: string, entries: string[]): void {
|
||||
if (entries.length === 0) return
|
||||
fs.mkdir(dirname(filePath), { recursive: true })
|
||||
.then(() => fs.appendFile(filePath, entries.join('\n') + '\n'))
|
||||
.catch(() => {})
|
||||
}
|
||||
|
||||
function initFingerprint(req: Record<string, unknown>): string {
|
||||
const tools = req.tools as Array<{ name?: string }> | undefined
|
||||
const system = req.system as unknown[] | string | undefined
|
||||
const sysLen =
|
||||
typeof system === 'string'
|
||||
? system.length
|
||||
: Array.isArray(system)
|
||||
? system.reduce(
|
||||
(n: number, b) => n + ((b as { text?: string }).text?.length ?? 0),
|
||||
0,
|
||||
)
|
||||
: 0
|
||||
const toolNames = tools?.map(t => t.name ?? '').join(',') ?? ''
|
||||
return `${req.model}|${toolNames}|${sysLen}`
|
||||
}
|
||||
|
||||
function dumpRequest(
|
||||
body: string,
|
||||
ts: string,
|
||||
state: DumpState,
|
||||
filePath: string,
|
||||
): void {
|
||||
try {
|
||||
const req = jsonParse(body) as Record<string, unknown>
|
||||
addApiRequestToCache(req)
|
||||
|
||||
if (process.env.USER_TYPE !== 'ant') return
|
||||
const entries: string[] = []
|
||||
const messages = (req.messages ?? []) as Array<{ role?: string }>
|
||||
|
||||
// Write init data (system, tools, metadata) on first request,
|
||||
// and a system_update entry whenever it changes.
|
||||
// Cheap fingerprint first: system+tools don't change between turns,
|
||||
// so skip the 300ms stringify when the shape is unchanged.
|
||||
const fingerprint = initFingerprint(req)
|
||||
if (!state.initialized || fingerprint !== state.lastInitFingerprint) {
|
||||
const { messages: _, ...initData } = req
|
||||
const initDataStr = jsonStringify(initData)
|
||||
const initDataHash = hashString(initDataStr)
|
||||
state.lastInitFingerprint = fingerprint
|
||||
if (!state.initialized) {
|
||||
state.initialized = true
|
||||
state.lastInitDataHash = initDataHash
|
||||
// Reuse initDataStr rather than re-serializing initData inside a wrapper.
|
||||
// timestamp from toISOString() contains no chars needing JSON escaping.
|
||||
entries.push(
|
||||
`{"type":"init","timestamp":"${ts}","data":${initDataStr}}`,
|
||||
)
|
||||
} else if (initDataHash !== state.lastInitDataHash) {
|
||||
state.lastInitDataHash = initDataHash
|
||||
entries.push(
|
||||
`{"type":"system_update","timestamp":"${ts}","data":${initDataStr}}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Write only new user messages (assistant messages captured in response)
|
||||
for (const msg of messages.slice(state.messageCountSeen)) {
|
||||
if (msg.role === 'user') {
|
||||
entries.push(
|
||||
jsonStringify({ type: 'message', timestamp: ts, data: msg }),
|
||||
)
|
||||
}
|
||||
}
|
||||
state.messageCountSeen = messages.length
|
||||
|
||||
appendToFile(filePath, entries)
|
||||
} catch {
|
||||
// Ignore parsing errors
|
||||
}
|
||||
}
|
||||
|
||||
export function createDumpPromptsFetch(
|
||||
agentIdOrSessionId: string,
|
||||
): ClientOptions['fetch'] {
|
||||
const filePath = getDumpPromptsPath(agentIdOrSessionId)
|
||||
|
||||
return async (input: RequestInfo | URL, init?: RequestInit) => {
|
||||
const state = dumpState.get(agentIdOrSessionId) ?? {
|
||||
initialized: false,
|
||||
messageCountSeen: 0,
|
||||
lastInitDataHash: '',
|
||||
lastInitFingerprint: '',
|
||||
}
|
||||
dumpState.set(agentIdOrSessionId, state)
|
||||
|
||||
let timestamp: string | undefined
|
||||
|
||||
if (init?.method === 'POST' && init.body) {
|
||||
timestamp = new Date().toISOString()
|
||||
// Parsing + stringifying the request (system prompt + tool schemas = MBs)
|
||||
// takes hundreds of ms. Defer so it doesn't block the actual API call —
|
||||
// this is debug tooling for /issue, not on the critical path.
|
||||
setImmediate(dumpRequest, init.body as string, timestamp, state, filePath)
|
||||
}
|
||||
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
const response = await globalThis.fetch(input, init)
|
||||
|
||||
// Save response async
|
||||
if (timestamp && response.ok && process.env.USER_TYPE === 'ant') {
|
||||
const cloned = response.clone()
|
||||
void (async () => {
|
||||
try {
|
||||
const isStreaming = cloned.headers
|
||||
.get('content-type')
|
||||
?.includes('text/event-stream')
|
||||
|
||||
let data: unknown
|
||||
if (isStreaming && cloned.body) {
|
||||
// Parse SSE stream into chunks
|
||||
const reader = cloned.body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
const chunks: unknown[] = []
|
||||
for (const event of buffer.split('\n\n')) {
|
||||
for (const line of event.split('\n')) {
|
||||
if (line.startsWith('data: ') && line !== 'data: [DONE]') {
|
||||
try {
|
||||
chunks.push(jsonParse(line.slice(6)))
|
||||
} catch {
|
||||
// Ignore parse errors
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
data = { stream: true, chunks }
|
||||
} else {
|
||||
data = await cloned.json()
|
||||
}
|
||||
|
||||
await fs.appendFile(
|
||||
filePath,
|
||||
jsonStringify({ type: 'response', timestamp, data }) + '\n',
|
||||
)
|
||||
} catch {
|
||||
// Best effort
|
||||
}
|
||||
})()
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
}
|
||||
22
src/services/api/emptyUsage.ts
Normal file
22
src/services/api/emptyUsage.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import type { NonNullableUsage } from '../../entrypoints/sdk/sdkUtilityTypes.js'
|
||||
|
||||
/**
|
||||
* Zero-initialized usage object. Extracted from logging.ts so that
|
||||
* bridge/replBridge.ts can import it without transitively pulling in
|
||||
* api/errors.ts → utils/messages.ts → BashTool.tsx → the world.
|
||||
*/
|
||||
export const EMPTY_USAGE: Readonly<NonNullableUsage> = {
|
||||
input_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
server_tool_use: { web_search_requests: 0, web_fetch_requests: 0 },
|
||||
service_tier: 'standard',
|
||||
cache_creation: {
|
||||
ephemeral_1h_input_tokens: 0,
|
||||
ephemeral_5m_input_tokens: 0,
|
||||
},
|
||||
inference_geo: '',
|
||||
iterations: [],
|
||||
speed: 'standard',
|
||||
}
|
||||
260
src/services/api/errorUtils.ts
Normal file
260
src/services/api/errorUtils.ts
Normal file
@@ -0,0 +1,260 @@
|
||||
import type { APIError } from '@anthropic-ai/sdk'
|
||||
|
||||
// SSL/TLS error codes from OpenSSL (used by both Node.js and Bun)
|
||||
// See: https://www.openssl.org/docs/man3.1/man3/X509_STORE_CTX_get_error.html
|
||||
const SSL_ERROR_CODES = new Set([
|
||||
// Certificate verification errors
|
||||
'UNABLE_TO_VERIFY_LEAF_SIGNATURE',
|
||||
'UNABLE_TO_GET_ISSUER_CERT',
|
||||
'UNABLE_TO_GET_ISSUER_CERT_LOCALLY',
|
||||
'CERT_SIGNATURE_FAILURE',
|
||||
'CERT_NOT_YET_VALID',
|
||||
'CERT_HAS_EXPIRED',
|
||||
'CERT_REVOKED',
|
||||
'CERT_REJECTED',
|
||||
'CERT_UNTRUSTED',
|
||||
// Self-signed certificate errors
|
||||
'DEPTH_ZERO_SELF_SIGNED_CERT',
|
||||
'SELF_SIGNED_CERT_IN_CHAIN',
|
||||
// Chain errors
|
||||
'CERT_CHAIN_TOO_LONG',
|
||||
'PATH_LENGTH_EXCEEDED',
|
||||
// Hostname/altname errors
|
||||
'ERR_TLS_CERT_ALTNAME_INVALID',
|
||||
'HOSTNAME_MISMATCH',
|
||||
// TLS handshake errors
|
||||
'ERR_TLS_HANDSHAKE_TIMEOUT',
|
||||
'ERR_SSL_WRONG_VERSION_NUMBER',
|
||||
'ERR_SSL_DECRYPTION_FAILED_OR_BAD_RECORD_MAC',
|
||||
])
|
||||
|
||||
export type ConnectionErrorDetails = {
|
||||
code: string
|
||||
message: string
|
||||
isSSLError: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts connection error details from the error cause chain.
|
||||
* The Anthropic SDK wraps underlying errors in the `cause` property.
|
||||
* This function walks the cause chain to find the root error code/message.
|
||||
*/
|
||||
export function extractConnectionErrorDetails(
|
||||
error: unknown,
|
||||
): ConnectionErrorDetails | null {
|
||||
if (!error || typeof error !== 'object') {
|
||||
return null
|
||||
}
|
||||
|
||||
// Walk the cause chain to find the root error with a code
|
||||
let current: unknown = error
|
||||
const maxDepth = 5 // Prevent infinite loops
|
||||
let depth = 0
|
||||
|
||||
while (current && depth < maxDepth) {
|
||||
if (
|
||||
current instanceof Error &&
|
||||
'code' in current &&
|
||||
typeof current.code === 'string'
|
||||
) {
|
||||
const code = current.code
|
||||
const isSSLError = SSL_ERROR_CODES.has(code)
|
||||
return {
|
||||
code,
|
||||
message: current.message,
|
||||
isSSLError,
|
||||
}
|
||||
}
|
||||
|
||||
// Move to the next cause in the chain
|
||||
if (
|
||||
current instanceof Error &&
|
||||
'cause' in current &&
|
||||
current.cause !== current
|
||||
) {
|
||||
current = current.cause
|
||||
depth++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an actionable hint for SSL/TLS errors, intended for contexts outside
|
||||
* the main API client (OAuth token exchange, preflight connectivity checks)
|
||||
* where `formatAPIError` doesn't apply.
|
||||
*
|
||||
* Motivation: enterprise users behind TLS-intercepting proxies (Zscaler et al.)
|
||||
* see OAuth complete in-browser but the CLI's token exchange silently fails
|
||||
* with a raw SSL code. Surfacing the likely fix saves a support round-trip.
|
||||
*/
|
||||
export function getSSLErrorHint(error: unknown): string | null {
|
||||
const details = extractConnectionErrorDetails(error)
|
||||
if (!details?.isSSLError) {
|
||||
return null
|
||||
}
|
||||
return `SSL certificate error (${details.code}). If you are behind a corporate proxy or TLS-intercepting firewall, set NODE_EXTRA_CA_CERTS to your CA bundle path, or ask IT to allowlist *.anthropic.com. Run /doctor for details.`
|
||||
}
|
||||
|
||||
/**
|
||||
* Strips HTML content (e.g., CloudFlare error pages) from a message string,
|
||||
* returning a user-friendly title or empty string if HTML is detected.
|
||||
* Returns the original message unchanged if no HTML is found.
|
||||
*/
|
||||
function sanitizeMessageHTML(message: string): string {
|
||||
if (message.includes('<!DOCTYPE html') || message.includes('<html')) {
|
||||
const titleMatch = message.match(/<title>([^<]+)<\/title>/)
|
||||
if (titleMatch && titleMatch[1]) {
|
||||
return titleMatch[1].trim()
|
||||
}
|
||||
return ''
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects if an error message contains HTML content (e.g., CloudFlare error pages)
|
||||
* and returns a user-friendly message instead
|
||||
*/
|
||||
export function sanitizeAPIError(apiError: APIError): string {
|
||||
const message = apiError.message
|
||||
if (!message) {
|
||||
// Sometimes message is undefined
|
||||
// TODO: figure out why
|
||||
return ''
|
||||
}
|
||||
return sanitizeMessageHTML(message)
|
||||
}
|
||||
|
||||
/**
|
||||
* Shapes of deserialized API errors from session JSONL.
|
||||
*
|
||||
* After JSON round-tripping, the SDK's APIError loses its `.message` property.
|
||||
* The actual message lives at different nesting levels depending on the provider:
|
||||
*
|
||||
* - Bedrock/proxy: `{ error: { message: "..." } }`
|
||||
* - Standard Anthropic API: `{ error: { error: { message: "..." } } }`
|
||||
* (the outer `.error` is the response body, the inner `.error` is the API error)
|
||||
*
|
||||
* See also: `getErrorMessage` in `logging.ts` which handles the same shapes.
|
||||
*/
|
||||
type NestedAPIError = {
|
||||
error?: {
|
||||
message?: string
|
||||
error?: { message?: string }
|
||||
}
|
||||
}
|
||||
|
||||
function hasNestedError(value: unknown): value is NestedAPIError {
|
||||
return (
|
||||
typeof value === 'object' &&
|
||||
value !== null &&
|
||||
'error' in value &&
|
||||
typeof value.error === 'object' &&
|
||||
value.error !== null
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a human-readable message from a deserialized API error that lacks
|
||||
* a top-level `.message`.
|
||||
*
|
||||
* Checks two nesting levels (deeper first for specificity):
|
||||
* 1. `error.error.error.message` — standard Anthropic API shape
|
||||
* 2. `error.error.message` — Bedrock shape
|
||||
*/
|
||||
function extractNestedErrorMessage(error: APIError): string | null {
|
||||
if (!hasNestedError(error)) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Access `.error` via the narrowed type so TypeScript sees the nested shape
|
||||
// instead of the SDK's `Object | undefined`.
|
||||
const narrowed: NestedAPIError = error
|
||||
const nested = narrowed.error
|
||||
|
||||
// Standard Anthropic API shape: { error: { error: { message } } }
|
||||
const deepMsg = nested?.error?.message
|
||||
if (typeof deepMsg === 'string' && deepMsg.length > 0) {
|
||||
const sanitized = sanitizeMessageHTML(deepMsg)
|
||||
if (sanitized.length > 0) {
|
||||
return sanitized
|
||||
}
|
||||
}
|
||||
|
||||
// Bedrock shape: { error: { message } }
|
||||
const msg = nested?.message
|
||||
if (typeof msg === 'string' && msg.length > 0) {
|
||||
const sanitized = sanitizeMessageHTML(msg)
|
||||
if (sanitized.length > 0) {
|
||||
return sanitized
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export function formatAPIError(error: APIError): string {
|
||||
// Extract connection error details from the cause chain
|
||||
const connectionDetails = extractConnectionErrorDetails(error)
|
||||
|
||||
if (connectionDetails) {
|
||||
const { code, isSSLError } = connectionDetails
|
||||
|
||||
// Handle timeout errors
|
||||
if (code === 'ETIMEDOUT') {
|
||||
return 'Request timed out. Check your internet connection and proxy settings'
|
||||
}
|
||||
|
||||
// Handle SSL/TLS errors with specific messages
|
||||
if (isSSLError) {
|
||||
switch (code) {
|
||||
case 'UNABLE_TO_VERIFY_LEAF_SIGNATURE':
|
||||
case 'UNABLE_TO_GET_ISSUER_CERT':
|
||||
case 'UNABLE_TO_GET_ISSUER_CERT_LOCALLY':
|
||||
return 'Unable to connect to API: SSL certificate verification failed. Check your proxy or corporate SSL certificates'
|
||||
case 'CERT_HAS_EXPIRED':
|
||||
return 'Unable to connect to API: SSL certificate has expired'
|
||||
case 'CERT_REVOKED':
|
||||
return 'Unable to connect to API: SSL certificate has been revoked'
|
||||
case 'DEPTH_ZERO_SELF_SIGNED_CERT':
|
||||
case 'SELF_SIGNED_CERT_IN_CHAIN':
|
||||
return 'Unable to connect to API: Self-signed certificate detected. Check your proxy or corporate SSL certificates'
|
||||
case 'ERR_TLS_CERT_ALTNAME_INVALID':
|
||||
case 'HOSTNAME_MISMATCH':
|
||||
return 'Unable to connect to API: SSL certificate hostname mismatch'
|
||||
case 'CERT_NOT_YET_VALID':
|
||||
return 'Unable to connect to API: SSL certificate is not yet valid'
|
||||
default:
|
||||
return `Unable to connect to API: SSL error (${code})`
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (error.message === 'Connection error.') {
|
||||
// If we have a code but it's not SSL, include it for debugging
|
||||
if (connectionDetails?.code) {
|
||||
return `Unable to connect to API (${connectionDetails.code})`
|
||||
}
|
||||
return 'Unable to connect to API. Check your internet connection'
|
||||
}
|
||||
|
||||
// Guard: when deserialized from JSONL (e.g. --resume), the error object may
|
||||
// be a plain object without a `.message` property. Return a safe fallback
|
||||
// instead of undefined, which would crash callers that access `.length`.
|
||||
if (!error.message) {
|
||||
return (
|
||||
extractNestedErrorMessage(error) ??
|
||||
`API error (status ${error.status ?? 'unknown'})`
|
||||
)
|
||||
}
|
||||
|
||||
const sanitizedMessage = sanitizeAPIError(error)
|
||||
// Use sanitized message if it's different from the original (i.e., HTML was sanitized)
|
||||
return sanitizedMessage !== error.message && sanitizedMessage.length > 0
|
||||
? sanitizedMessage
|
||||
: error.message
|
||||
}
|
||||
1207
src/services/api/errors.ts
Normal file
1207
src/services/api/errors.ts
Normal file
File diff suppressed because it is too large
Load Diff
748
src/services/api/filesApi.ts
Normal file
748
src/services/api/filesApi.ts
Normal file
@@ -0,0 +1,748 @@
|
||||
/**
|
||||
* Files API client for managing files
|
||||
*
|
||||
* This module provides functionality to download and upload files to Anthropic Public Files API.
|
||||
* Used by the Claude Code agent to download file attachments at session startup.
|
||||
*
|
||||
* API Reference: https://docs.anthropic.com/en/api/files-content
|
||||
*/
|
||||
|
||||
import axios from 'axios'
|
||||
import { randomUUID } from 'crypto'
|
||||
import * as fs from 'fs/promises'
|
||||
import * as path from 'path'
|
||||
import { count } from '../../utils/array.js'
|
||||
import { getCwd } from '../../utils/cwd.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { sleep } from '../../utils/sleep.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
|
||||
// Files API is currently in beta. oauth-2025-04-20 enables Bearer OAuth
|
||||
// on public-api routes (auth.py: "oauth_auth" not in beta_versions → 404).
|
||||
const FILES_API_BETA_HEADER = 'files-api-2025-04-14,oauth-2025-04-20'
|
||||
const ANTHROPIC_VERSION = '2023-06-01'
|
||||
|
||||
// API base URL - uses ANTHROPIC_BASE_URL set by env-manager for the appropriate environment
|
||||
// Falls back to public API for standalone usage
|
||||
function getDefaultApiBaseUrl(): string {
|
||||
return (
|
||||
process.env.ANTHROPIC_BASE_URL ||
|
||||
process.env.CLAUDE_CODE_API_BASE_URL ||
|
||||
'https://api.anthropic.com'
|
||||
)
|
||||
}
|
||||
|
||||
function logDebugError(message: string): void {
|
||||
logForDebugging(`[files-api] ${message}`, { level: 'error' })
|
||||
}
|
||||
|
||||
function logDebug(message: string): void {
|
||||
logForDebugging(`[files-api] ${message}`)
|
||||
}
|
||||
|
||||
/**
|
||||
* File specification parsed from CLI args
|
||||
* Format: --file=<file_id>:<relative_path>
|
||||
*/
|
||||
export type File = {
|
||||
fileId: string
|
||||
relativePath: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for the files API client
|
||||
*/
|
||||
export type FilesApiConfig = {
|
||||
/** OAuth token for authentication (from session JWT) */
|
||||
oauthToken: string
|
||||
/** Base URL for the API (default: https://api.anthropic.com) */
|
||||
baseUrl?: string
|
||||
/** Session ID for creating session-specific directories */
|
||||
sessionId: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of a file download operation
|
||||
*/
|
||||
export type DownloadResult = {
|
||||
fileId: string
|
||||
path: string
|
||||
success: boolean
|
||||
error?: string
|
||||
bytesWritten?: number
|
||||
}
|
||||
|
||||
const MAX_RETRIES = 3
|
||||
const BASE_DELAY_MS = 500
|
||||
const MAX_FILE_SIZE_BYTES = 500 * 1024 * 1024 // 500MB
|
||||
|
||||
/**
|
||||
* Result type for retry operations - signals whether to continue retrying
|
||||
*/
|
||||
type RetryResult<T> = { done: true; value: T } | { done: false; error?: string }
|
||||
|
||||
/**
|
||||
* Executes an operation with exponential backoff retry logic
|
||||
*
|
||||
* @param operation - Operation name for logging
|
||||
* @param attemptFn - Function to execute on each attempt, returns RetryResult
|
||||
* @returns The successful result value
|
||||
* @throws Error if all retries exhausted
|
||||
*/
|
||||
async function retryWithBackoff<T>(
|
||||
operation: string,
|
||||
attemptFn: (attempt: number) => Promise<RetryResult<T>>,
|
||||
): Promise<T> {
|
||||
let lastError = ''
|
||||
|
||||
for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) {
|
||||
const result = await attemptFn(attempt)
|
||||
|
||||
if (result.done) {
|
||||
return result.value
|
||||
}
|
||||
|
||||
lastError = result.error || `${operation} failed`
|
||||
logDebug(
|
||||
`${operation} attempt ${attempt}/${MAX_RETRIES} failed: ${lastError}`,
|
||||
)
|
||||
|
||||
if (attempt < MAX_RETRIES) {
|
||||
const delayMs = BASE_DELAY_MS * Math.pow(2, attempt - 1)
|
||||
logDebug(`Retrying ${operation} in ${delayMs}ms...`)
|
||||
await sleep(delayMs)
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(`${lastError} after ${MAX_RETRIES} attempts`)
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads a single file from the Anthropic Public Files API
|
||||
*
|
||||
* @param fileId - The file ID (e.g., "file_011CNha8iCJcU1wXNR6q4V8w")
|
||||
* @param config - Files API configuration
|
||||
* @returns The file content as a Buffer
|
||||
*/
|
||||
export async function downloadFile(
|
||||
fileId: string,
|
||||
config: FilesApiConfig,
|
||||
): Promise<Buffer> {
|
||||
const baseUrl = config.baseUrl || getDefaultApiBaseUrl()
|
||||
const url = `${baseUrl}/v1/files/${fileId}/content`
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${config.oauthToken}`,
|
||||
'anthropic-version': ANTHROPIC_VERSION,
|
||||
'anthropic-beta': FILES_API_BETA_HEADER,
|
||||
}
|
||||
|
||||
logDebug(`Downloading file ${fileId} from ${url}`)
|
||||
|
||||
return retryWithBackoff(`Download file ${fileId}`, async () => {
|
||||
try {
|
||||
const response = await axios.get(url, {
|
||||
headers,
|
||||
responseType: 'arraybuffer',
|
||||
timeout: 60000, // 60 second timeout for large files
|
||||
validateStatus: status => status < 500,
|
||||
})
|
||||
|
||||
if (response.status === 200) {
|
||||
logDebug(`Downloaded file ${fileId} (${response.data.length} bytes)`)
|
||||
return { done: true, value: Buffer.from(response.data) }
|
||||
}
|
||||
|
||||
// Non-retriable errors - throw immediately
|
||||
if (response.status === 404) {
|
||||
throw new Error(`File not found: ${fileId}`)
|
||||
}
|
||||
if (response.status === 401) {
|
||||
throw new Error('Authentication failed: invalid or missing API key')
|
||||
}
|
||||
if (response.status === 403) {
|
||||
throw new Error(`Access denied to file: ${fileId}`)
|
||||
}
|
||||
|
||||
return { done: false, error: `status ${response.status}` }
|
||||
} catch (error) {
|
||||
if (!axios.isAxiosError(error)) {
|
||||
throw error
|
||||
}
|
||||
return { done: false, error: error.message }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes a relative path, strips redundant prefixes, and builds the full
|
||||
* download path under {basePath}/{session_id}/uploads/.
|
||||
* Returns null if the path is invalid (e.g., path traversal).
|
||||
*/
|
||||
export function buildDownloadPath(
|
||||
basePath: string,
|
||||
sessionId: string,
|
||||
relativePath: string,
|
||||
): string | null {
|
||||
const normalized = path.normalize(relativePath)
|
||||
if (normalized.startsWith('..')) {
|
||||
logDebugError(
|
||||
`Invalid file path: ${relativePath}. Path must not traverse above workspace`,
|
||||
)
|
||||
return null
|
||||
}
|
||||
|
||||
const uploadsBase = path.join(basePath, sessionId, 'uploads')
|
||||
const redundantPrefixes = [
|
||||
path.join(basePath, sessionId, 'uploads') + path.sep,
|
||||
path.sep + 'uploads' + path.sep,
|
||||
]
|
||||
const matchedPrefix = redundantPrefixes.find(p => normalized.startsWith(p))
|
||||
const cleanPath = matchedPrefix
|
||||
? normalized.slice(matchedPrefix.length)
|
||||
: normalized
|
||||
return path.join(uploadsBase, cleanPath)
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads a file and saves it to the session-specific workspace directory
|
||||
*
|
||||
* @param attachment - The file attachment to download
|
||||
* @param config - Files API configuration
|
||||
* @returns Download result with success/failure status
|
||||
*/
|
||||
export async function downloadAndSaveFile(
|
||||
attachment: File,
|
||||
config: FilesApiConfig,
|
||||
): Promise<DownloadResult> {
|
||||
const { fileId, relativePath } = attachment
|
||||
const fullPath = buildDownloadPath(getCwd(), config.sessionId, relativePath)
|
||||
|
||||
if (!fullPath) {
|
||||
return {
|
||||
fileId,
|
||||
path: '',
|
||||
success: false,
|
||||
error: `Invalid file path: ${relativePath}`,
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// Download the file content
|
||||
const content = await downloadFile(fileId, config)
|
||||
|
||||
// Ensure the parent directory exists
|
||||
const parentDir = path.dirname(fullPath)
|
||||
await fs.mkdir(parentDir, { recursive: true })
|
||||
|
||||
// Write the file
|
||||
await fs.writeFile(fullPath, content)
|
||||
|
||||
logDebug(`Saved file ${fileId} to ${fullPath} (${content.length} bytes)`)
|
||||
|
||||
return {
|
||||
fileId,
|
||||
path: fullPath,
|
||||
success: true,
|
||||
bytesWritten: content.length,
|
||||
}
|
||||
} catch (error) {
|
||||
logDebugError(`Failed to download file ${fileId}: ${errorMessage(error)}`)
|
||||
if (error instanceof Error) {
|
||||
logError(error)
|
||||
}
|
||||
|
||||
return {
|
||||
fileId,
|
||||
path: fullPath,
|
||||
success: false,
|
||||
error: errorMessage(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default concurrency limit for parallel downloads
|
||||
const DEFAULT_CONCURRENCY = 5
|
||||
|
||||
/**
|
||||
* Execute promises with limited concurrency
|
||||
*
|
||||
* @param items - Items to process
|
||||
* @param fn - Async function to apply to each item
|
||||
* @param concurrency - Maximum concurrent operations
|
||||
* @returns Results in the same order as input items
|
||||
*/
|
||||
async function parallelWithLimit<T, R>(
|
||||
items: T[],
|
||||
fn: (item: T, index: number) => Promise<R>,
|
||||
concurrency: number,
|
||||
): Promise<R[]> {
|
||||
const results: R[] = new Array(items.length)
|
||||
let currentIndex = 0
|
||||
|
||||
async function worker(): Promise<void> {
|
||||
while (currentIndex < items.length) {
|
||||
const index = currentIndex++
|
||||
const item = items[index]
|
||||
if (item !== undefined) {
|
||||
results[index] = await fn(item, index)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start workers up to the concurrency limit
|
||||
const workers: Promise<void>[] = []
|
||||
const workerCount = Math.min(concurrency, items.length)
|
||||
for (let i = 0; i < workerCount; i++) {
|
||||
workers.push(worker())
|
||||
}
|
||||
|
||||
await Promise.all(workers)
|
||||
return results
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads all file attachments for a session in parallel
|
||||
*
|
||||
* @param attachments - List of file attachments to download
|
||||
* @param config - Files API configuration
|
||||
* @param concurrency - Maximum concurrent downloads (default: 5)
|
||||
* @returns Array of download results in the same order as input
|
||||
*/
|
||||
export async function downloadSessionFiles(
|
||||
files: File[],
|
||||
config: FilesApiConfig,
|
||||
concurrency: number = DEFAULT_CONCURRENCY,
|
||||
): Promise<DownloadResult[]> {
|
||||
if (files.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
logDebug(
|
||||
`Downloading ${files.length} file(s) for session ${config.sessionId}`,
|
||||
)
|
||||
const startTime = Date.now()
|
||||
|
||||
// Download files in parallel with concurrency limit
|
||||
const results = await parallelWithLimit(
|
||||
files,
|
||||
file => downloadAndSaveFile(file, config),
|
||||
concurrency,
|
||||
)
|
||||
|
||||
const elapsedMs = Date.now() - startTime
|
||||
const successCount = count(results, r => r.success)
|
||||
logDebug(
|
||||
`Downloaded ${successCount}/${files.length} file(s) in ${elapsedMs}ms`,
|
||||
)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Upload Functions (BYOC mode)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Result of a file upload operation
|
||||
*/
|
||||
export type UploadResult =
|
||||
| {
|
||||
path: string
|
||||
fileId: string
|
||||
size: number
|
||||
success: true
|
||||
}
|
||||
| {
|
||||
path: string
|
||||
error: string
|
||||
success: false
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload a single file to the Files API (BYOC mode)
|
||||
*
|
||||
* Size validation is performed after reading the file to avoid TOCTOU race
|
||||
* conditions where the file size could change between initial check and upload.
|
||||
*
|
||||
* @param filePath - Absolute path to the file to upload
|
||||
* @param relativePath - Relative path for the file (used as filename in API)
|
||||
* @param config - Files API configuration
|
||||
* @returns Upload result with success/failure status
|
||||
*/
|
||||
export async function uploadFile(
|
||||
filePath: string,
|
||||
relativePath: string,
|
||||
config: FilesApiConfig,
|
||||
opts?: { signal?: AbortSignal },
|
||||
): Promise<UploadResult> {
|
||||
const baseUrl = config.baseUrl || getDefaultApiBaseUrl()
|
||||
const url = `${baseUrl}/v1/files`
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${config.oauthToken}`,
|
||||
'anthropic-version': ANTHROPIC_VERSION,
|
||||
'anthropic-beta': FILES_API_BETA_HEADER,
|
||||
}
|
||||
|
||||
logDebug(`Uploading file ${filePath} as ${relativePath}`)
|
||||
|
||||
// Read file content first (outside retry loop since it's not a network operation)
|
||||
let content: Buffer
|
||||
try {
|
||||
content = await fs.readFile(filePath)
|
||||
} catch (error) {
|
||||
logEvent('tengu_file_upload_failed', {
|
||||
error_type:
|
||||
'file_read' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {
|
||||
path: relativePath,
|
||||
error: errorMessage(error),
|
||||
success: false,
|
||||
}
|
||||
}
|
||||
|
||||
const fileSize = content.length
|
||||
|
||||
if (fileSize > MAX_FILE_SIZE_BYTES) {
|
||||
logEvent('tengu_file_upload_failed', {
|
||||
error_type:
|
||||
'file_too_large' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {
|
||||
path: relativePath,
|
||||
error: `File exceeds maximum size of ${MAX_FILE_SIZE_BYTES} bytes (actual: ${fileSize})`,
|
||||
success: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Use crypto.randomUUID for boundary to avoid collisions when uploads start same millisecond
|
||||
const boundary = `----FormBoundary${randomUUID()}`
|
||||
const filename = path.basename(relativePath)
|
||||
|
||||
// Build the multipart body
|
||||
const bodyParts: Buffer[] = []
|
||||
|
||||
// File part
|
||||
bodyParts.push(
|
||||
Buffer.from(
|
||||
`--${boundary}\r\n` +
|
||||
`Content-Disposition: form-data; name="file"; filename="${filename}"\r\n` +
|
||||
`Content-Type: application/octet-stream\r\n\r\n`,
|
||||
),
|
||||
)
|
||||
bodyParts.push(content)
|
||||
bodyParts.push(Buffer.from('\r\n'))
|
||||
|
||||
// Purpose part
|
||||
bodyParts.push(
|
||||
Buffer.from(
|
||||
`--${boundary}\r\n` +
|
||||
`Content-Disposition: form-data; name="purpose"\r\n\r\n` +
|
||||
`user_data\r\n`,
|
||||
),
|
||||
)
|
||||
|
||||
// End boundary
|
||||
bodyParts.push(Buffer.from(`--${boundary}--\r\n`))
|
||||
|
||||
const body = Buffer.concat(bodyParts)
|
||||
|
||||
try {
|
||||
return await retryWithBackoff(`Upload file ${relativePath}`, async () => {
|
||||
try {
|
||||
const response = await axios.post(url, body, {
|
||||
headers: {
|
||||
...headers,
|
||||
'Content-Type': `multipart/form-data; boundary=${boundary}`,
|
||||
'Content-Length': body.length.toString(),
|
||||
},
|
||||
timeout: 120000, // 2 minute timeout for uploads
|
||||
signal: opts?.signal,
|
||||
validateStatus: status => status < 500,
|
||||
})
|
||||
|
||||
if (response.status === 200 || response.status === 201) {
|
||||
const fileId = response.data?.id
|
||||
if (!fileId) {
|
||||
return {
|
||||
done: false,
|
||||
error: 'Upload succeeded but no file ID returned',
|
||||
}
|
||||
}
|
||||
logDebug(`Uploaded file ${filePath} -> ${fileId} (${fileSize} bytes)`)
|
||||
return {
|
||||
done: true,
|
||||
value: {
|
||||
path: relativePath,
|
||||
fileId,
|
||||
size: fileSize,
|
||||
success: true as const,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Non-retriable errors - throw to exit retry loop
|
||||
if (response.status === 401) {
|
||||
logEvent('tengu_file_upload_failed', {
|
||||
error_type:
|
||||
'auth' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
throw new UploadNonRetriableError(
|
||||
'Authentication failed: invalid or missing API key',
|
||||
)
|
||||
}
|
||||
|
||||
if (response.status === 403) {
|
||||
logEvent('tengu_file_upload_failed', {
|
||||
error_type:
|
||||
'forbidden' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
throw new UploadNonRetriableError('Access denied for upload')
|
||||
}
|
||||
|
||||
if (response.status === 413) {
|
||||
logEvent('tengu_file_upload_failed', {
|
||||
error_type:
|
||||
'size' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
throw new UploadNonRetriableError('File too large for upload')
|
||||
}
|
||||
|
||||
return { done: false, error: `status ${response.status}` }
|
||||
} catch (error) {
|
||||
// Non-retriable errors propagate up
|
||||
if (error instanceof UploadNonRetriableError) {
|
||||
throw error
|
||||
}
|
||||
if (axios.isCancel(error)) {
|
||||
throw new UploadNonRetriableError('Upload canceled')
|
||||
}
|
||||
// Network errors are retriable
|
||||
if (axios.isAxiosError(error)) {
|
||||
return { done: false, error: error.message }
|
||||
}
|
||||
throw error
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
if (error instanceof UploadNonRetriableError) {
|
||||
return {
|
||||
path: relativePath,
|
||||
error: error.message,
|
||||
success: false,
|
||||
}
|
||||
}
|
||||
logEvent('tengu_file_upload_failed', {
|
||||
error_type:
|
||||
'network' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {
|
||||
path: relativePath,
|
||||
error: errorMessage(error),
|
||||
success: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Error class for non-retriable upload failures */
|
||||
class UploadNonRetriableError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'UploadNonRetriableError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload multiple files in parallel with concurrency limit (BYOC mode)
|
||||
*
|
||||
* @param files - Array of files to upload (path and relativePath)
|
||||
* @param config - Files API configuration
|
||||
* @param concurrency - Maximum concurrent uploads (default: 5)
|
||||
* @returns Array of upload results in the same order as input
|
||||
*/
|
||||
export async function uploadSessionFiles(
|
||||
files: Array<{ path: string; relativePath: string }>,
|
||||
config: FilesApiConfig,
|
||||
concurrency: number = DEFAULT_CONCURRENCY,
|
||||
): Promise<UploadResult[]> {
|
||||
if (files.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
logDebug(`Uploading ${files.length} file(s) for session ${config.sessionId}`)
|
||||
const startTime = Date.now()
|
||||
|
||||
const results = await parallelWithLimit(
|
||||
files,
|
||||
file => uploadFile(file.path, file.relativePath, config),
|
||||
concurrency,
|
||||
)
|
||||
|
||||
const elapsedMs = Date.now() - startTime
|
||||
const successCount = count(results, r => r.success)
|
||||
logDebug(`Uploaded ${successCount}/${files.length} file(s) in ${elapsedMs}ms`)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// List Files Functions (1P/Cloud mode)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* File metadata returned from listFilesCreatedAfter
|
||||
*/
|
||||
export type FileMetadata = {
|
||||
filename: string
|
||||
fileId: string
|
||||
size: number
|
||||
}
|
||||
|
||||
/**
|
||||
* List files created after a given timestamp (1P/Cloud mode).
|
||||
* Uses the public GET /v1/files endpoint with after_created_at query param.
|
||||
* Handles pagination via after_id cursor when has_more is true.
|
||||
*
|
||||
* @param afterCreatedAt - ISO 8601 timestamp to filter files created after
|
||||
* @param config - Files API configuration
|
||||
* @returns Array of file metadata for files created after the timestamp
|
||||
*/
|
||||
export async function listFilesCreatedAfter(
|
||||
afterCreatedAt: string,
|
||||
config: FilesApiConfig,
|
||||
): Promise<FileMetadata[]> {
|
||||
const baseUrl = config.baseUrl || getDefaultApiBaseUrl()
|
||||
const headers = {
|
||||
Authorization: `Bearer ${config.oauthToken}`,
|
||||
'anthropic-version': ANTHROPIC_VERSION,
|
||||
'anthropic-beta': FILES_API_BETA_HEADER,
|
||||
}
|
||||
|
||||
logDebug(`Listing files created after ${afterCreatedAt}`)
|
||||
|
||||
const allFiles: FileMetadata[] = []
|
||||
let afterId: string | undefined
|
||||
|
||||
// Paginate through results
|
||||
while (true) {
|
||||
const params: Record<string, string> = {
|
||||
after_created_at: afterCreatedAt,
|
||||
}
|
||||
if (afterId) {
|
||||
params.after_id = afterId
|
||||
}
|
||||
|
||||
const page = await retryWithBackoff(
|
||||
`List files after ${afterCreatedAt}`,
|
||||
async () => {
|
||||
try {
|
||||
const response = await axios.get(`${baseUrl}/v1/files`, {
|
||||
headers,
|
||||
params,
|
||||
timeout: 60000,
|
||||
validateStatus: status => status < 500,
|
||||
})
|
||||
|
||||
if (response.status === 200) {
|
||||
return { done: true, value: response.data }
|
||||
}
|
||||
|
||||
if (response.status === 401) {
|
||||
logEvent('tengu_file_list_failed', {
|
||||
error_type:
|
||||
'auth' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
throw new Error('Authentication failed: invalid or missing API key')
|
||||
}
|
||||
if (response.status === 403) {
|
||||
logEvent('tengu_file_list_failed', {
|
||||
error_type:
|
||||
'forbidden' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
throw new Error('Access denied to list files')
|
||||
}
|
||||
|
||||
return { done: false, error: `status ${response.status}` }
|
||||
} catch (error) {
|
||||
if (!axios.isAxiosError(error)) {
|
||||
throw error
|
||||
}
|
||||
logEvent('tengu_file_list_failed', {
|
||||
error_type:
|
||||
'network' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return { done: false, error: error.message }
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
const files = page.data || []
|
||||
for (const f of files) {
|
||||
allFiles.push({
|
||||
filename: f.filename,
|
||||
fileId: f.id,
|
||||
size: f.size_bytes,
|
||||
})
|
||||
}
|
||||
|
||||
if (!page.has_more) {
|
||||
break
|
||||
}
|
||||
|
||||
// Use the last file's ID as cursor for next page
|
||||
const lastFile = files.at(-1)
|
||||
if (!lastFile?.id) {
|
||||
break
|
||||
}
|
||||
afterId = lastFile.id
|
||||
}
|
||||
|
||||
logDebug(`Listed ${allFiles.length} files created after ${afterCreatedAt}`)
|
||||
return allFiles
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Parse Functions
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Parse file attachment specs from CLI arguments
|
||||
* Format: <file_id>:<relative_path>
|
||||
*
|
||||
* @param fileSpecs - Array of file spec strings
|
||||
* @returns Parsed file attachments
|
||||
*/
|
||||
export function parseFileSpecs(fileSpecs: string[]): File[] {
|
||||
const files: File[] = []
|
||||
|
||||
// Sandbox-gateway may pass multiple specs as a single space-separated string
|
||||
const expandedSpecs = fileSpecs.flatMap(s => s.split(' ').filter(Boolean))
|
||||
|
||||
for (const spec of expandedSpecs) {
|
||||
const colonIndex = spec.indexOf(':')
|
||||
if (colonIndex === -1) {
|
||||
continue
|
||||
}
|
||||
|
||||
const fileId = spec.substring(0, colonIndex)
|
||||
const relativePath = spec.substring(colonIndex + 1)
|
||||
|
||||
if (!fileId || !relativePath) {
|
||||
logDebugError(
|
||||
`Invalid file spec: ${spec}. Both file_id and path are required`,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
files.push({ fileId, relativePath })
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
||||
60
src/services/api/firstTokenDate.ts
Normal file
60
src/services/api/firstTokenDate.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
import axios from 'axios'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../../utils/config.js'
|
||||
import { getAuthHeaders } from '../../utils/http.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getClaudeCodeUserAgent } from '../../utils/userAgent.js'
|
||||
|
||||
/**
|
||||
* Fetch the user's first Claude Code token date and store in config.
|
||||
* This is called after successful login to cache when they started using Claude Code.
|
||||
*/
|
||||
export async function fetchAndStoreClaudeCodeFirstTokenDate(): Promise<void> {
|
||||
try {
|
||||
const config = getGlobalConfig()
|
||||
|
||||
if (config.claudeCodeFirstTokenDate !== undefined) {
|
||||
return
|
||||
}
|
||||
|
||||
const authHeaders = getAuthHeaders()
|
||||
if (authHeaders.error) {
|
||||
logError(new Error(`Failed to get auth headers: ${authHeaders.error}`))
|
||||
return
|
||||
}
|
||||
|
||||
const oauthConfig = getOauthConfig()
|
||||
const url = `${oauthConfig.BASE_API_URL}/api/organization/claude_code_first_token_date`
|
||||
|
||||
const response = await axios.get(url, {
|
||||
headers: {
|
||||
...authHeaders.headers,
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
},
|
||||
timeout: 10000,
|
||||
})
|
||||
|
||||
const firstTokenDate = response.data?.first_token_date ?? null
|
||||
|
||||
// Validate the date if it's not null
|
||||
if (firstTokenDate !== null) {
|
||||
const dateTime = new Date(firstTokenDate).getTime()
|
||||
if (isNaN(dateTime)) {
|
||||
logError(
|
||||
new Error(
|
||||
`Received invalid first_token_date from API: ${firstTokenDate}`,
|
||||
),
|
||||
)
|
||||
// Don't save invalid dates
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
claudeCodeFirstTokenDate: firstTokenDate,
|
||||
}))
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
357
src/services/api/grove.ts
Normal file
357
src/services/api/grove.ts
Normal file
@@ -0,0 +1,357 @@
|
||||
import axios from 'axios'
|
||||
import memoize from 'lodash-es/memoize.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from 'src/services/analytics/index.js'
|
||||
import { getOauthAccountInfo, isConsumerSubscriber } from 'src/utils/auth.js'
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
import { gracefulShutdown } from 'src/utils/gracefulShutdown.js'
|
||||
import { isEssentialTrafficOnly } from 'src/utils/privacyLevel.js'
|
||||
import { writeToStderr } from 'src/utils/process.js'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../../utils/config.js'
|
||||
import {
|
||||
getAuthHeaders,
|
||||
getUserAgent,
|
||||
withOAuth401Retry,
|
||||
} from '../../utils/http.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getClaudeCodeUserAgent } from '../../utils/userAgent.js'
|
||||
|
||||
// Cache expiration: 24 hours
|
||||
const GROVE_CACHE_EXPIRATION_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
export type AccountSettings = {
|
||||
grove_enabled: boolean | null
|
||||
grove_notice_viewed_at: string | null
|
||||
}
|
||||
|
||||
export type GroveConfig = {
|
||||
grove_enabled: boolean
|
||||
domain_excluded: boolean
|
||||
notice_is_grace_period: boolean
|
||||
notice_reminder_frequency: number | null
|
||||
}
|
||||
|
||||
/**
|
||||
* Result type that distinguishes between API failure and success.
|
||||
* - success: true means API call succeeded (data may still contain null fields)
|
||||
* - success: false means API call failed after retry
|
||||
*/
|
||||
export type ApiResult<T> = { success: true; data: T } | { success: false }
|
||||
|
||||
/**
|
||||
* Get the current Grove settings for the user account.
|
||||
* Returns ApiResult to distinguish between API failure and success.
|
||||
* Uses existing OAuth 401 retry, then returns failure if that doesn't help.
|
||||
*
|
||||
* Memoized for the session to avoid redundant per-render requests.
|
||||
* Cache is invalidated in updateGroveSettings() so post-toggle reads are fresh.
|
||||
*/
|
||||
export const getGroveSettings = memoize(
|
||||
async (): Promise<ApiResult<AccountSettings>> => {
|
||||
// Grove is a notification feature; during an outage, skipping it is correct.
|
||||
if (isEssentialTrafficOnly()) {
|
||||
return { success: false }
|
||||
}
|
||||
try {
|
||||
const response = await withOAuth401Retry(() => {
|
||||
const authHeaders = getAuthHeaders()
|
||||
if (authHeaders.error) {
|
||||
throw new Error(`Failed to get auth headers: ${authHeaders.error}`)
|
||||
}
|
||||
return axios.get<AccountSettings>(
|
||||
`${getOauthConfig().BASE_API_URL}/api/oauth/account/settings`,
|
||||
{
|
||||
headers: {
|
||||
...authHeaders.headers,
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
},
|
||||
},
|
||||
)
|
||||
})
|
||||
return { success: true, data: response.data }
|
||||
} catch (err) {
|
||||
logError(err)
|
||||
// Don't cache failures — transient network issues would lock the user
|
||||
// out of privacy settings for the entire session (deadlock: dialog needs
|
||||
// success to render the toggle, toggle calls updateGroveSettings which
|
||||
// is the only other place the cache is cleared).
|
||||
getGroveSettings.cache.clear?.()
|
||||
return { success: false }
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
/**
|
||||
* Mark that the Grove notice has been viewed by the user
|
||||
*/
|
||||
export async function markGroveNoticeViewed(): Promise<void> {
|
||||
try {
|
||||
await withOAuth401Retry(() => {
|
||||
const authHeaders = getAuthHeaders()
|
||||
if (authHeaders.error) {
|
||||
throw new Error(`Failed to get auth headers: ${authHeaders.error}`)
|
||||
}
|
||||
return axios.post(
|
||||
`${getOauthConfig().BASE_API_URL}/api/oauth/account/grove_notice_viewed`,
|
||||
{},
|
||||
{
|
||||
headers: {
|
||||
...authHeaders.headers,
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
},
|
||||
},
|
||||
)
|
||||
})
|
||||
// This mutates grove_notice_viewed_at server-side — Grove.tsx:87 reads it
|
||||
// to decide whether to show the dialog. Without invalidation a same-session
|
||||
// remount would read stale viewed_at:null and re-show the dialog.
|
||||
getGroveSettings.cache.clear?.()
|
||||
} catch (err) {
|
||||
logError(err)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update Grove settings for the user account
|
||||
*/
|
||||
export async function updateGroveSettings(
|
||||
groveEnabled: boolean,
|
||||
): Promise<void> {
|
||||
try {
|
||||
await withOAuth401Retry(() => {
|
||||
const authHeaders = getAuthHeaders()
|
||||
if (authHeaders.error) {
|
||||
throw new Error(`Failed to get auth headers: ${authHeaders.error}`)
|
||||
}
|
||||
return axios.patch(
|
||||
`${getOauthConfig().BASE_API_URL}/api/oauth/account/settings`,
|
||||
{
|
||||
grove_enabled: groveEnabled,
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
...authHeaders.headers,
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
},
|
||||
},
|
||||
)
|
||||
})
|
||||
// Invalidate memoized settings so the post-toggle confirmation
|
||||
// read in privacy-settings.tsx picks up the new value.
|
||||
getGroveSettings.cache.clear?.()
|
||||
} catch (err) {
|
||||
logError(err)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if user is qualified for Grove (non-blocking, cache-first).
|
||||
*
|
||||
* This function never blocks on network - it returns cached data immediately
|
||||
* and fetches in the background if needed. On cold start (no cache), it returns
|
||||
* false and the Grove dialog won't show until the next session.
|
||||
*/
|
||||
export async function isQualifiedForGrove(): Promise<boolean> {
|
||||
if (!isConsumerSubscriber()) {
|
||||
return false
|
||||
}
|
||||
|
||||
const accountId = getOauthAccountInfo()?.accountUuid
|
||||
if (!accountId) {
|
||||
return false
|
||||
}
|
||||
|
||||
const globalConfig = getGlobalConfig()
|
||||
const cachedEntry = globalConfig.groveConfigCache?.[accountId]
|
||||
const now = Date.now()
|
||||
|
||||
// No cache - trigger background fetch and return false (non-blocking)
|
||||
// The Grove dialog won't show this session, but will next time if eligible
|
||||
if (!cachedEntry) {
|
||||
logForDebugging(
|
||||
'Grove: No cache, fetching config in background (dialog skipped this session)',
|
||||
)
|
||||
void fetchAndStoreGroveConfig(accountId)
|
||||
return false
|
||||
}
|
||||
|
||||
// Cache exists but is stale - return cached value and refresh in background
|
||||
if (now - cachedEntry.timestamp > GROVE_CACHE_EXPIRATION_MS) {
|
||||
logForDebugging(
|
||||
'Grove: Cache stale, returning cached data and refreshing in background',
|
||||
)
|
||||
void fetchAndStoreGroveConfig(accountId)
|
||||
return cachedEntry.grove_enabled
|
||||
}
|
||||
|
||||
// Cache is fresh - return it immediately
|
||||
logForDebugging('Grove: Using fresh cached config')
|
||||
return cachedEntry.grove_enabled
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch Grove config from API and store in cache
|
||||
*/
|
||||
async function fetchAndStoreGroveConfig(accountId: string): Promise<void> {
|
||||
try {
|
||||
const result = await getGroveNoticeConfig()
|
||||
if (!result.success) {
|
||||
return
|
||||
}
|
||||
const groveEnabled = result.data.grove_enabled
|
||||
const cachedEntry = getGlobalConfig().groveConfigCache?.[accountId]
|
||||
if (
|
||||
cachedEntry?.grove_enabled === groveEnabled &&
|
||||
Date.now() - cachedEntry.timestamp <= GROVE_CACHE_EXPIRATION_MS
|
||||
) {
|
||||
return
|
||||
}
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
groveConfigCache: {
|
||||
...current.groveConfigCache,
|
||||
[accountId]: {
|
||||
grove_enabled: groveEnabled,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
},
|
||||
}))
|
||||
} catch (err) {
|
||||
logForDebugging(`Grove: Failed to fetch and store config: ${err}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Grove Statsig configuration from the API.
|
||||
* Returns ApiResult to distinguish between API failure and success.
|
||||
* Uses existing OAuth 401 retry, then returns failure if that doesn't help.
|
||||
*/
|
||||
export const getGroveNoticeConfig = memoize(
|
||||
async (): Promise<ApiResult<GroveConfig>> => {
|
||||
// Grove is a notification feature; during an outage, skipping it is correct.
|
||||
if (isEssentialTrafficOnly()) {
|
||||
return { success: false }
|
||||
}
|
||||
try {
|
||||
const response = await withOAuth401Retry(() => {
|
||||
const authHeaders = getAuthHeaders()
|
||||
if (authHeaders.error) {
|
||||
throw new Error(`Failed to get auth headers: ${authHeaders.error}`)
|
||||
}
|
||||
return axios.get<GroveConfig>(
|
||||
`${getOauthConfig().BASE_API_URL}/api/claude_code_grove`,
|
||||
{
|
||||
headers: {
|
||||
...authHeaders.headers,
|
||||
'User-Agent': getUserAgent(),
|
||||
},
|
||||
timeout: 3000, // Short timeout - if slow, skip Grove dialog
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
// Map the API response to the GroveConfig type
|
||||
const {
|
||||
grove_enabled,
|
||||
domain_excluded,
|
||||
notice_is_grace_period,
|
||||
notice_reminder_frequency,
|
||||
} = response.data
|
||||
|
||||
return {
|
||||
success: true,
|
||||
data: {
|
||||
grove_enabled,
|
||||
domain_excluded: domain_excluded ?? false,
|
||||
notice_is_grace_period: notice_is_grace_period ?? true,
|
||||
notice_reminder_frequency,
|
||||
},
|
||||
}
|
||||
} catch (err) {
|
||||
logForDebugging(`Failed to fetch Grove notice config: ${err}`)
|
||||
return { success: false }
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
/**
|
||||
* Determines whether the Grove dialog should be shown.
|
||||
* Returns false if either API call failed (after retry) - we hide the dialog on API failure.
|
||||
*/
|
||||
export function calculateShouldShowGrove(
|
||||
settingsResult: ApiResult<AccountSettings>,
|
||||
configResult: ApiResult<GroveConfig>,
|
||||
showIfAlreadyViewed: boolean,
|
||||
): boolean {
|
||||
// Hide dialog on API failure (after retry)
|
||||
if (!settingsResult.success || !configResult.success) {
|
||||
return false
|
||||
}
|
||||
|
||||
const settings = settingsResult.data
|
||||
const config = configResult.data
|
||||
|
||||
const hasChosen = settings.grove_enabled !== null
|
||||
if (hasChosen) {
|
||||
return false
|
||||
}
|
||||
if (showIfAlreadyViewed) {
|
||||
return true
|
||||
}
|
||||
if (!config.notice_is_grace_period) {
|
||||
return true
|
||||
}
|
||||
// Check if we need to remind the user to accept the terms and choose
|
||||
// whether to help improve Claude.
|
||||
const reminderFrequency = config.notice_reminder_frequency
|
||||
if (reminderFrequency !== null && settings.grove_notice_viewed_at) {
|
||||
const daysSinceViewed = Math.floor(
|
||||
(Date.now() - new Date(settings.grove_notice_viewed_at).getTime()) /
|
||||
(1000 * 60 * 60 * 24),
|
||||
)
|
||||
return daysSinceViewed >= reminderFrequency
|
||||
} else {
|
||||
// Show if never viewed before
|
||||
const viewedAt = settings.grove_notice_viewed_at
|
||||
return viewedAt === null || viewedAt === undefined
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkGroveForNonInteractive(): Promise<void> {
|
||||
const [settingsResult, configResult] = await Promise.all([
|
||||
getGroveSettings(),
|
||||
getGroveNoticeConfig(),
|
||||
])
|
||||
|
||||
// Check if user hasn't made a choice yet (returns false on API failure)
|
||||
const shouldShowGrove = calculateShouldShowGrove(
|
||||
settingsResult,
|
||||
configResult,
|
||||
false,
|
||||
)
|
||||
|
||||
if (shouldShowGrove) {
|
||||
// shouldShowGrove is only true if both API calls succeeded
|
||||
const config = configResult.success ? configResult.data : null
|
||||
logEvent('tengu_grove_print_viewed', {
|
||||
dismissable:
|
||||
config?.notice_is_grace_period as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
if (config === null || config.notice_is_grace_period) {
|
||||
// Grace period is still active - show informational message and continue
|
||||
writeToStderr(
|
||||
'\nAn update to our Consumer Terms and Privacy Policy will take effect on October 8, 2025. Run `claude` to review the updated terms.\n\n',
|
||||
)
|
||||
await markGroveNoticeViewed()
|
||||
} else {
|
||||
// Grace period has ended - show error message and exit
|
||||
writeToStderr(
|
||||
'\n[ACTION REQUIRED] An update to our Consumer Terms and Privacy Policy has taken effect on October 8, 2025. You must run `claude` to review the updated terms.\n\n',
|
||||
)
|
||||
await gracefulShutdown(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
788
src/services/api/logging.ts
Normal file
788
src/services/api/logging.ts
Normal file
@@ -0,0 +1,788 @@
|
||||
import { feature } from 'bun:bundle'
|
||||
import { APIError } from '@anthropic-ai/sdk'
|
||||
import type {
|
||||
BetaStopReason,
|
||||
BetaUsage as Usage,
|
||||
} from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
|
||||
import {
|
||||
addToTotalDurationState,
|
||||
consumePostCompaction,
|
||||
getIsNonInteractiveSession,
|
||||
getLastApiCompletionTimestamp,
|
||||
getTeleportedSessionInfo,
|
||||
markFirstTeleportMessageLogged,
|
||||
setLastApiCompletionTimestamp,
|
||||
} from 'src/bootstrap/state.js'
|
||||
import type { QueryChainTracking } from 'src/Tool.js'
|
||||
import { isConnectorTextBlock } from 'src/types/connectorText.js'
|
||||
import type { AssistantMessage } from 'src/types/message.js'
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
import type { EffortLevel } from 'src/utils/effort.js'
|
||||
import { logError } from 'src/utils/log.js'
|
||||
import { getAPIProviderForStatsig } from 'src/utils/model/providers.js'
|
||||
import type { PermissionMode } from 'src/utils/permissions/PermissionMode.js'
|
||||
import { jsonStringify } from 'src/utils/slowOperations.js'
|
||||
import { logOTelEvent } from 'src/utils/telemetry/events.js'
|
||||
import {
|
||||
endLLMRequestSpan,
|
||||
isBetaTracingEnabled,
|
||||
type Span,
|
||||
} from 'src/utils/telemetry/sessionTracing.js'
|
||||
import type { NonNullableUsage } from '../../entrypoints/sdk/sdkUtilityTypes.js'
|
||||
import { consumeInvokingRequestId } from '../../utils/agentContext.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
import { sanitizeToolNameForAnalytics } from '../analytics/metadata.js'
|
||||
import { EMPTY_USAGE } from './emptyUsage.js'
|
||||
import { classifyAPIError } from './errors.js'
|
||||
import { extractConnectionErrorDetails } from './errorUtils.js'
|
||||
|
||||
export type { NonNullableUsage }
|
||||
export { EMPTY_USAGE }
|
||||
|
||||
// Strategy used for global prompt caching
|
||||
export type GlobalCacheStrategy = 'tool_based' | 'system_prompt' | 'none'
|
||||
|
||||
function getErrorMessage(error: unknown): string {
|
||||
if (error instanceof APIError) {
|
||||
const body = error.error as { error?: { message?: string } } | undefined
|
||||
if (body?.error?.message) return body.error.message
|
||||
}
|
||||
return error instanceof Error ? error.message : String(error)
|
||||
}
|
||||
|
||||
type KnownGateway =
|
||||
| 'litellm'
|
||||
| 'helicone'
|
||||
| 'portkey'
|
||||
| 'cloudflare-ai-gateway'
|
||||
| 'kong'
|
||||
| 'braintrust'
|
||||
| 'databricks'
|
||||
|
||||
// Gateway fingerprints for detecting AI gateways from response headers
|
||||
const GATEWAY_FINGERPRINTS: Partial<
|
||||
Record<KnownGateway, { prefixes: string[] }>
|
||||
> = {
|
||||
// https://docs.litellm.ai/docs/proxy/response_headers
|
||||
litellm: {
|
||||
prefixes: ['x-litellm-'],
|
||||
},
|
||||
// https://docs.helicone.ai/helicone-headers/header-directory
|
||||
helicone: {
|
||||
prefixes: ['helicone-'],
|
||||
},
|
||||
// https://portkey.ai/docs/api-reference/response-schema
|
||||
portkey: {
|
||||
prefixes: ['x-portkey-'],
|
||||
},
|
||||
// https://developers.cloudflare.com/ai-gateway/evaluations/add-human-feedback-api/
|
||||
'cloudflare-ai-gateway': {
|
||||
prefixes: ['cf-aig-'],
|
||||
},
|
||||
// https://developer.konghq.com/ai-gateway/ — X-Kong-Upstream-Latency, X-Kong-Proxy-Latency
|
||||
kong: {
|
||||
prefixes: ['x-kong-'],
|
||||
},
|
||||
// https://www.braintrust.dev/docs/guides/proxy — x-bt-used-endpoint, x-bt-cached
|
||||
braintrust: {
|
||||
prefixes: ['x-bt-'],
|
||||
},
|
||||
}
|
||||
|
||||
// Gateways that use provider-owned domains (not self-hosted), so the
|
||||
// ANTHROPIC_BASE_URL hostname is a reliable signal even without a
|
||||
// distinctive response header.
|
||||
const GATEWAY_HOST_SUFFIXES: Partial<Record<KnownGateway, string[]>> = {
|
||||
// https://docs.databricks.com/aws/en/ai-gateway/
|
||||
databricks: [
|
||||
'.cloud.databricks.com',
|
||||
'.azuredatabricks.net',
|
||||
'.gcp.databricks.com',
|
||||
],
|
||||
}
|
||||
|
||||
function detectGateway({
|
||||
headers,
|
||||
baseUrl,
|
||||
}: {
|
||||
headers?: globalThis.Headers
|
||||
baseUrl?: string
|
||||
}): KnownGateway | undefined {
|
||||
if (headers) {
|
||||
// Header names are already lowercase from the Headers API
|
||||
const headerNames: string[] = []
|
||||
headers.forEach((_, key) => headerNames.push(key))
|
||||
for (const [gw, { prefixes }] of Object.entries(GATEWAY_FINGERPRINTS)) {
|
||||
if (prefixes.some(p => headerNames.some(h => h.startsWith(p)))) {
|
||||
return gw as KnownGateway
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (baseUrl) {
|
||||
try {
|
||||
const host = new URL(baseUrl).hostname.toLowerCase()
|
||||
for (const [gw, suffixes] of Object.entries(GATEWAY_HOST_SUFFIXES)) {
|
||||
if (suffixes.some(s => host.endsWith(s))) {
|
||||
return gw as KnownGateway
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// malformed URL — ignore
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
function getAnthropicEnvMetadata() {
|
||||
return {
|
||||
...(process.env.ANTHROPIC_BASE_URL
|
||||
? {
|
||||
baseUrl: process.env
|
||||
.ANTHROPIC_BASE_URL as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(process.env.ANTHROPIC_MODEL
|
||||
? {
|
||||
envModel: process.env
|
||||
.ANTHROPIC_MODEL as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(process.env.ANTHROPIC_SMALL_FAST_MODEL
|
||||
? {
|
||||
envSmallFastModel: process.env
|
||||
.ANTHROPIC_SMALL_FAST_MODEL as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
}
|
||||
}
|
||||
|
||||
function getBuildAgeMinutes(): number | undefined {
|
||||
if (!MACRO.BUILD_TIME) return undefined
|
||||
const buildTime = new Date(MACRO.BUILD_TIME).getTime()
|
||||
if (isNaN(buildTime)) return undefined
|
||||
return Math.floor((Date.now() - buildTime) / 60000)
|
||||
}
|
||||
|
||||
export function logAPIQuery({
|
||||
model,
|
||||
messagesLength,
|
||||
temperature,
|
||||
betas,
|
||||
permissionMode,
|
||||
querySource,
|
||||
queryTracking,
|
||||
thinkingType,
|
||||
effortValue,
|
||||
fastMode,
|
||||
previousRequestId,
|
||||
}: {
|
||||
model: string
|
||||
messagesLength: number
|
||||
temperature: number
|
||||
betas?: string[]
|
||||
permissionMode?: PermissionMode
|
||||
querySource: string
|
||||
queryTracking?: QueryChainTracking
|
||||
thinkingType?: 'adaptive' | 'enabled' | 'disabled'
|
||||
effortValue?: EffortLevel | null
|
||||
fastMode?: boolean
|
||||
previousRequestId?: string | null
|
||||
}): void {
|
||||
logEvent('tengu_api_query', {
|
||||
model: model as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
messagesLength,
|
||||
temperature: temperature,
|
||||
provider: getAPIProviderForStatsig(),
|
||||
buildAgeMins: getBuildAgeMinutes(),
|
||||
...(betas?.length
|
||||
? {
|
||||
betas: betas.join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
permissionMode:
|
||||
permissionMode as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
querySource:
|
||||
querySource as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...(queryTracking
|
||||
? {
|
||||
queryChainId:
|
||||
queryTracking.chainId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
queryDepth: queryTracking.depth,
|
||||
}
|
||||
: {}),
|
||||
thinkingType:
|
||||
thinkingType as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
effortValue:
|
||||
effortValue as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
fastMode,
|
||||
...(previousRequestId
|
||||
? {
|
||||
previousRequestId:
|
||||
previousRequestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...getAnthropicEnvMetadata(),
|
||||
})
|
||||
}
|
||||
|
||||
export function logAPIError({
|
||||
error,
|
||||
model,
|
||||
messageCount,
|
||||
messageTokens,
|
||||
durationMs,
|
||||
durationMsIncludingRetries,
|
||||
attempt,
|
||||
requestId,
|
||||
clientRequestId,
|
||||
didFallBackToNonStreaming,
|
||||
promptCategory,
|
||||
headers,
|
||||
queryTracking,
|
||||
querySource,
|
||||
llmSpan,
|
||||
fastMode,
|
||||
previousRequestId,
|
||||
}: {
|
||||
error: unknown
|
||||
model: string
|
||||
messageCount: number
|
||||
messageTokens?: number
|
||||
durationMs: number
|
||||
durationMsIncludingRetries: number
|
||||
attempt: number
|
||||
requestId?: string | null
|
||||
/** Client-generated ID sent as x-client-request-id header (survives timeouts) */
|
||||
clientRequestId?: string
|
||||
didFallBackToNonStreaming?: boolean
|
||||
promptCategory?: string
|
||||
headers?: globalThis.Headers
|
||||
queryTracking?: QueryChainTracking
|
||||
querySource?: string
|
||||
/** The span from startLLMRequestSpan - pass this to correctly match responses to requests */
|
||||
llmSpan?: Span
|
||||
fastMode?: boolean
|
||||
previousRequestId?: string | null
|
||||
}): void {
|
||||
const gateway = detectGateway({
|
||||
headers:
|
||||
error instanceof APIError && error.headers ? error.headers : headers,
|
||||
baseUrl: process.env.ANTHROPIC_BASE_URL,
|
||||
})
|
||||
|
||||
const errStr = getErrorMessage(error)
|
||||
const status = error instanceof APIError ? String(error.status) : undefined
|
||||
const errorType = classifyAPIError(error)
|
||||
|
||||
// Log detailed connection error info to debug logs (visible via --debug)
|
||||
const connectionDetails = extractConnectionErrorDetails(error)
|
||||
if (connectionDetails) {
|
||||
const sslLabel = connectionDetails.isSSLError ? ' (SSL error)' : ''
|
||||
logForDebugging(
|
||||
`Connection error details: code=${connectionDetails.code}${sslLabel}, message=${connectionDetails.message}`,
|
||||
{ level: 'error' },
|
||||
)
|
||||
}
|
||||
|
||||
const invocation = consumeInvokingRequestId()
|
||||
|
||||
if (clientRequestId) {
|
||||
logForDebugging(
|
||||
`API error x-client-request-id=${clientRequestId} (give this to the API team for server-log lookup)`,
|
||||
{ level: 'error' },
|
||||
)
|
||||
}
|
||||
|
||||
logError(error as Error)
|
||||
logEvent('tengu_api_error', {
|
||||
model: model as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
error: errStr as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
status:
|
||||
status as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
errorType:
|
||||
errorType as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
messageCount,
|
||||
messageTokens,
|
||||
durationMs,
|
||||
durationMsIncludingRetries,
|
||||
attempt,
|
||||
provider: getAPIProviderForStatsig(),
|
||||
requestId:
|
||||
(requestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS) ||
|
||||
undefined,
|
||||
...(invocation
|
||||
? {
|
||||
invokingRequestId:
|
||||
invocation.invokingRequestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
invocationKind:
|
||||
invocation.invocationKind as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
clientRequestId:
|
||||
(clientRequestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS) ||
|
||||
undefined,
|
||||
didFallBackToNonStreaming,
|
||||
...(promptCategory
|
||||
? {
|
||||
promptCategory:
|
||||
promptCategory as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(gateway
|
||||
? {
|
||||
gateway:
|
||||
gateway as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(queryTracking
|
||||
? {
|
||||
queryChainId:
|
||||
queryTracking.chainId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
queryDepth: queryTracking.depth,
|
||||
}
|
||||
: {}),
|
||||
...(querySource
|
||||
? {
|
||||
querySource:
|
||||
querySource as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
fastMode,
|
||||
...(previousRequestId
|
||||
? {
|
||||
previousRequestId:
|
||||
previousRequestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...getAnthropicEnvMetadata(),
|
||||
})
|
||||
|
||||
// Log API error event for OTLP
|
||||
void logOTelEvent('api_error', {
|
||||
model: model,
|
||||
error: errStr,
|
||||
status_code: String(status),
|
||||
duration_ms: String(durationMs),
|
||||
attempt: String(attempt),
|
||||
speed: fastMode ? 'fast' : 'normal',
|
||||
})
|
||||
|
||||
// Pass the span to correctly match responses to requests when beta tracing is enabled
|
||||
endLLMRequestSpan(llmSpan, {
|
||||
success: false,
|
||||
statusCode: status ? parseInt(status) : undefined,
|
||||
error: errStr,
|
||||
attempt,
|
||||
})
|
||||
|
||||
// Log first error for teleported sessions (reliability tracking)
|
||||
const teleportInfo = getTeleportedSessionInfo()
|
||||
if (teleportInfo?.isTeleported && !teleportInfo.hasLoggedFirstMessage) {
|
||||
logEvent('tengu_teleport_first_message_error', {
|
||||
session_id:
|
||||
teleportInfo.sessionId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
error_type:
|
||||
errorType as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
markFirstTeleportMessageLogged()
|
||||
}
|
||||
}
|
||||
|
||||
function logAPISuccess({
|
||||
model,
|
||||
preNormalizedModel,
|
||||
messageCount,
|
||||
messageTokens,
|
||||
usage,
|
||||
durationMs,
|
||||
durationMsIncludingRetries,
|
||||
attempt,
|
||||
ttftMs,
|
||||
requestId,
|
||||
stopReason,
|
||||
costUSD,
|
||||
didFallBackToNonStreaming,
|
||||
querySource,
|
||||
gateway,
|
||||
queryTracking,
|
||||
permissionMode,
|
||||
globalCacheStrategy,
|
||||
textContentLength,
|
||||
thinkingContentLength,
|
||||
toolUseContentLengths,
|
||||
connectorTextBlockCount,
|
||||
fastMode,
|
||||
previousRequestId,
|
||||
betas,
|
||||
}: {
|
||||
model: string
|
||||
preNormalizedModel: string
|
||||
messageCount: number
|
||||
messageTokens: number
|
||||
usage: Usage
|
||||
durationMs: number
|
||||
durationMsIncludingRetries: number
|
||||
attempt: number
|
||||
ttftMs: number | null
|
||||
requestId: string | null
|
||||
stopReason: BetaStopReason | null
|
||||
costUSD: number
|
||||
didFallBackToNonStreaming: boolean
|
||||
querySource: string
|
||||
gateway?: KnownGateway
|
||||
queryTracking?: QueryChainTracking
|
||||
permissionMode?: PermissionMode
|
||||
globalCacheStrategy?: GlobalCacheStrategy
|
||||
textContentLength?: number
|
||||
thinkingContentLength?: number
|
||||
toolUseContentLengths?: Record<string, number>
|
||||
connectorTextBlockCount?: number
|
||||
fastMode?: boolean
|
||||
previousRequestId?: string | null
|
||||
betas?: string[]
|
||||
}): void {
|
||||
const isNonInteractiveSession = getIsNonInteractiveSession()
|
||||
const isPostCompaction = consumePostCompaction()
|
||||
const hasPrintFlag =
|
||||
process.argv.includes('-p') || process.argv.includes('--print')
|
||||
|
||||
const now = Date.now()
|
||||
const lastCompletion = getLastApiCompletionTimestamp()
|
||||
const timeSinceLastApiCallMs =
|
||||
lastCompletion !== null ? now - lastCompletion : undefined
|
||||
|
||||
const invocation = consumeInvokingRequestId()
|
||||
|
||||
logEvent('tengu_api_success', {
|
||||
model: model as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...(preNormalizedModel !== model
|
||||
? {
|
||||
preNormalizedModel:
|
||||
preNormalizedModel as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(betas?.length
|
||||
? {
|
||||
betas: betas.join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
messageCount,
|
||||
messageTokens,
|
||||
inputTokens: usage.input_tokens,
|
||||
outputTokens: usage.output_tokens,
|
||||
cachedInputTokens: usage.cache_read_input_tokens ?? 0,
|
||||
uncachedInputTokens: usage.cache_creation_input_tokens ?? 0,
|
||||
durationMs: durationMs,
|
||||
durationMsIncludingRetries: durationMsIncludingRetries,
|
||||
attempt: attempt,
|
||||
ttftMs: ttftMs ?? undefined,
|
||||
buildAgeMins: getBuildAgeMinutes(),
|
||||
provider: getAPIProviderForStatsig(),
|
||||
requestId:
|
||||
(requestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS) ??
|
||||
undefined,
|
||||
...(invocation
|
||||
? {
|
||||
invokingRequestId:
|
||||
invocation.invokingRequestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
invocationKind:
|
||||
invocation.invocationKind as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
stop_reason:
|
||||
(stopReason as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS) ??
|
||||
undefined,
|
||||
costUSD,
|
||||
didFallBackToNonStreaming,
|
||||
isNonInteractiveSession,
|
||||
print: hasPrintFlag,
|
||||
isTTY: process.stdout.isTTY ?? false,
|
||||
querySource:
|
||||
querySource as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...(gateway
|
||||
? {
|
||||
gateway:
|
||||
gateway as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(queryTracking
|
||||
? {
|
||||
queryChainId:
|
||||
queryTracking.chainId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
queryDepth: queryTracking.depth,
|
||||
}
|
||||
: {}),
|
||||
permissionMode:
|
||||
permissionMode as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...(globalCacheStrategy
|
||||
? {
|
||||
globalCacheStrategy:
|
||||
globalCacheStrategy as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(textContentLength !== undefined
|
||||
? ({
|
||||
textContentLength,
|
||||
} as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS)
|
||||
: {}),
|
||||
...(thinkingContentLength !== undefined
|
||||
? ({
|
||||
thinkingContentLength,
|
||||
} as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS)
|
||||
: {}),
|
||||
...(toolUseContentLengths !== undefined
|
||||
? ({
|
||||
toolUseContentLengths: jsonStringify(
|
||||
toolUseContentLengths,
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
} as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS)
|
||||
: {}),
|
||||
...(connectorTextBlockCount !== undefined
|
||||
? ({
|
||||
connectorTextBlockCount,
|
||||
} as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS)
|
||||
: {}),
|
||||
fastMode,
|
||||
// Log cache_deleted_input_tokens for cache editing analysis. Casts needed
|
||||
// because the field is intentionally not on NonNullableUsage (excluded from
|
||||
// external builds). Set by updateUsage() when cache editing is active.
|
||||
...(feature('CACHED_MICROCOMPACT') &&
|
||||
((usage as unknown as { cache_deleted_input_tokens?: number })
|
||||
.cache_deleted_input_tokens ?? 0) > 0
|
||||
? {
|
||||
cacheDeletedInputTokens: (
|
||||
usage as unknown as { cache_deleted_input_tokens: number }
|
||||
).cache_deleted_input_tokens,
|
||||
}
|
||||
: {}),
|
||||
...(previousRequestId
|
||||
? {
|
||||
previousRequestId:
|
||||
previousRequestId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}
|
||||
: {}),
|
||||
...(isPostCompaction ? { isPostCompaction } : {}),
|
||||
...getAnthropicEnvMetadata(),
|
||||
timeSinceLastApiCallMs,
|
||||
})
|
||||
|
||||
setLastApiCompletionTimestamp(now)
|
||||
}
|
||||
|
||||
export function logAPISuccessAndDuration({
|
||||
model,
|
||||
preNormalizedModel,
|
||||
start,
|
||||
startIncludingRetries,
|
||||
ttftMs,
|
||||
usage,
|
||||
attempt,
|
||||
messageCount,
|
||||
messageTokens,
|
||||
requestId,
|
||||
stopReason,
|
||||
didFallBackToNonStreaming,
|
||||
querySource,
|
||||
headers,
|
||||
costUSD,
|
||||
queryTracking,
|
||||
permissionMode,
|
||||
newMessages,
|
||||
llmSpan,
|
||||
globalCacheStrategy,
|
||||
requestSetupMs,
|
||||
attemptStartTimes,
|
||||
fastMode,
|
||||
previousRequestId,
|
||||
betas,
|
||||
}: {
|
||||
model: string
|
||||
preNormalizedModel: string
|
||||
start: number
|
||||
startIncludingRetries: number
|
||||
ttftMs: number | null
|
||||
usage: NonNullableUsage
|
||||
attempt: number
|
||||
messageCount: number
|
||||
messageTokens: number
|
||||
requestId: string | null
|
||||
stopReason: BetaStopReason | null
|
||||
didFallBackToNonStreaming: boolean
|
||||
querySource: string
|
||||
headers?: globalThis.Headers
|
||||
costUSD: number
|
||||
queryTracking?: QueryChainTracking
|
||||
permissionMode?: PermissionMode
|
||||
/** Assistant messages from the response - used to extract model_output and thinking_output
|
||||
* when beta tracing is enabled */
|
||||
newMessages?: AssistantMessage[]
|
||||
/** The span from startLLMRequestSpan - pass this to correctly match responses to requests */
|
||||
llmSpan?: Span
|
||||
/** Strategy used for global prompt caching: 'tool_based', 'system_prompt', or 'none' */
|
||||
globalCacheStrategy?: GlobalCacheStrategy
|
||||
/** Time spent in pre-request setup before the successful attempt */
|
||||
requestSetupMs?: number
|
||||
/** Timestamps (Date.now()) of each attempt start — used for retry sub-spans in Perfetto */
|
||||
attemptStartTimes?: number[]
|
||||
fastMode?: boolean
|
||||
/** Request ID from the previous API call in this session */
|
||||
previousRequestId?: string | null
|
||||
betas?: string[]
|
||||
}): void {
|
||||
const gateway = detectGateway({
|
||||
headers,
|
||||
baseUrl: process.env.ANTHROPIC_BASE_URL,
|
||||
})
|
||||
|
||||
let textContentLength: number | undefined
|
||||
let thinkingContentLength: number | undefined
|
||||
let toolUseContentLengths: Record<string, number> | undefined
|
||||
let connectorTextBlockCount: number | undefined
|
||||
|
||||
if (newMessages) {
|
||||
let textLen = 0
|
||||
let thinkingLen = 0
|
||||
let hasToolUse = false
|
||||
const toolLengths: Record<string, number> = {}
|
||||
let connectorCount = 0
|
||||
|
||||
for (const msg of newMessages) {
|
||||
for (const block of msg.message.content) {
|
||||
if (block.type === 'text') {
|
||||
textLen += block.text.length
|
||||
} else if (feature('CONNECTOR_TEXT') && isConnectorTextBlock(block)) {
|
||||
connectorCount++
|
||||
} else if (block.type === 'thinking') {
|
||||
thinkingLen += block.thinking.length
|
||||
} else if (
|
||||
block.type === 'tool_use' ||
|
||||
block.type === 'server_tool_use' ||
|
||||
block.type === 'mcp_tool_use'
|
||||
) {
|
||||
const inputLen = jsonStringify(block.input).length
|
||||
const sanitizedName = sanitizeToolNameForAnalytics(block.name)
|
||||
toolLengths[sanitizedName] =
|
||||
(toolLengths[sanitizedName] ?? 0) + inputLen
|
||||
hasToolUse = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textContentLength = textLen
|
||||
thinkingContentLength = thinkingLen > 0 ? thinkingLen : undefined
|
||||
toolUseContentLengths = hasToolUse ? toolLengths : undefined
|
||||
connectorTextBlockCount = connectorCount > 0 ? connectorCount : undefined
|
||||
}
|
||||
|
||||
const durationMs = Date.now() - start
|
||||
const durationMsIncludingRetries = Date.now() - startIncludingRetries
|
||||
addToTotalDurationState(durationMsIncludingRetries, durationMs)
|
||||
|
||||
logAPISuccess({
|
||||
model,
|
||||
preNormalizedModel,
|
||||
messageCount,
|
||||
messageTokens,
|
||||
usage,
|
||||
durationMs,
|
||||
durationMsIncludingRetries,
|
||||
attempt,
|
||||
ttftMs,
|
||||
requestId,
|
||||
stopReason,
|
||||
costUSD,
|
||||
didFallBackToNonStreaming,
|
||||
querySource,
|
||||
gateway,
|
||||
queryTracking,
|
||||
permissionMode,
|
||||
globalCacheStrategy,
|
||||
textContentLength,
|
||||
thinkingContentLength,
|
||||
toolUseContentLengths,
|
||||
connectorTextBlockCount,
|
||||
fastMode,
|
||||
previousRequestId,
|
||||
betas,
|
||||
})
|
||||
// Log API request event for OTLP
|
||||
void logOTelEvent('api_request', {
|
||||
model,
|
||||
input_tokens: String(usage.input_tokens),
|
||||
output_tokens: String(usage.output_tokens),
|
||||
cache_read_tokens: String(usage.cache_read_input_tokens),
|
||||
cache_creation_tokens: String(usage.cache_creation_input_tokens),
|
||||
cost_usd: String(costUSD),
|
||||
duration_ms: String(durationMs),
|
||||
speed: fastMode ? 'fast' : 'normal',
|
||||
})
|
||||
|
||||
// Extract model output, thinking output, and tool call flag when beta tracing is enabled
|
||||
let modelOutput: string | undefined
|
||||
let thinkingOutput: string | undefined
|
||||
let hasToolCall: boolean | undefined
|
||||
|
||||
if (isBetaTracingEnabled() && newMessages) {
|
||||
// Model output - visible to all users
|
||||
modelOutput =
|
||||
newMessages
|
||||
.flatMap(m =>
|
||||
m.message.content
|
||||
.filter(c => c.type === 'text')
|
||||
.map(c => (c as { type: 'text'; text: string }).text),
|
||||
)
|
||||
.join('\n') || undefined
|
||||
|
||||
// Thinking output - Ant-only (build-time gated)
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
thinkingOutput =
|
||||
newMessages
|
||||
.flatMap(m =>
|
||||
m.message.content
|
||||
.filter(c => c.type === 'thinking')
|
||||
.map(c => (c as { type: 'thinking'; thinking: string }).thinking),
|
||||
)
|
||||
.join('\n') || undefined
|
||||
}
|
||||
|
||||
// Check if any tool_use blocks were in the output
|
||||
hasToolCall = newMessages.some(m =>
|
||||
m.message.content.some(c => c.type === 'tool_use'),
|
||||
)
|
||||
}
|
||||
|
||||
// Pass the span to correctly match responses to requests when beta tracing is enabled
|
||||
endLLMRequestSpan(llmSpan, {
|
||||
success: true,
|
||||
inputTokens: usage.input_tokens,
|
||||
outputTokens: usage.output_tokens,
|
||||
cacheReadTokens: usage.cache_read_input_tokens,
|
||||
cacheCreationTokens: usage.cache_creation_input_tokens,
|
||||
attempt,
|
||||
modelOutput,
|
||||
thinkingOutput,
|
||||
hasToolCall,
|
||||
ttftMs: ttftMs ?? undefined,
|
||||
requestSetupMs,
|
||||
attemptStartTimes,
|
||||
})
|
||||
|
||||
// Log first successful message for teleported sessions (reliability tracking)
|
||||
const teleportInfo = getTeleportedSessionInfo()
|
||||
if (teleportInfo?.isTeleported && !teleportInfo.hasLoggedFirstMessage) {
|
||||
logEvent('tengu_teleport_first_message_success', {
|
||||
session_id:
|
||||
teleportInfo.sessionId as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
markFirstTeleportMessageLogged()
|
||||
}
|
||||
}
|
||||
159
src/services/api/metricsOptOut.ts
Normal file
159
src/services/api/metricsOptOut.ts
Normal file
@@ -0,0 +1,159 @@
|
||||
import axios from 'axios'
|
||||
import { hasProfileScope, isClaudeAISubscriber } from '../../utils/auth.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../../utils/config.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { getAuthHeaders, withOAuth401Retry } from '../../utils/http.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { memoizeWithTTLAsync } from '../../utils/memoize.js'
|
||||
import { isEssentialTrafficOnly } from '../../utils/privacyLevel.js'
|
||||
import { getClaudeCodeUserAgent } from '../../utils/userAgent.js'
|
||||
|
||||
type MetricsEnabledResponse = {
|
||||
metrics_logging_enabled: boolean
|
||||
}
|
||||
|
||||
type MetricsStatus = {
|
||||
enabled: boolean
|
||||
hasError: boolean
|
||||
}
|
||||
|
||||
// In-memory TTL — dedupes calls within a single process
|
||||
const CACHE_TTL_MS = 60 * 60 * 1000
|
||||
|
||||
// Disk TTL — org settings rarely change. When disk cache is fresher than this,
|
||||
// we skip the network entirely (no background refresh). This is what collapses
|
||||
// N `claude -p` invocations into ~1 API call/day.
|
||||
const DISK_CACHE_TTL_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
/**
|
||||
* Internal function to call the API and check if metrics are enabled
|
||||
* This is wrapped by memoizeWithTTLAsync to add caching behavior
|
||||
*/
|
||||
async function _fetchMetricsEnabled(): Promise<MetricsEnabledResponse> {
|
||||
const authResult = getAuthHeaders()
|
||||
if (authResult.error) {
|
||||
throw new Error(`Auth error: ${authResult.error}`)
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
...authResult.headers,
|
||||
}
|
||||
|
||||
const endpoint = `https://api.anthropic.com/api/claude_code/organizations/metrics_enabled`
|
||||
const response = await axios.get<MetricsEnabledResponse>(endpoint, {
|
||||
headers,
|
||||
timeout: 5000,
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
|
||||
async function _checkMetricsEnabledAPI(): Promise<MetricsStatus> {
|
||||
// Incident kill switch: skip the network call when nonessential traffic is disabled.
|
||||
// Returning enabled:false sheds load at the consumer (bigqueryExporter skips
|
||||
// export). Matches the non-subscriber early-return shape below.
|
||||
if (isEssentialTrafficOnly()) {
|
||||
return { enabled: false, hasError: false }
|
||||
}
|
||||
|
||||
try {
|
||||
const data = await withOAuth401Retry(_fetchMetricsEnabled, {
|
||||
also403Revoked: true,
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
`Metrics opt-out API response: enabled=${data.metrics_logging_enabled}`,
|
||||
)
|
||||
|
||||
return {
|
||||
enabled: data.metrics_logging_enabled,
|
||||
hasError: false,
|
||||
}
|
||||
} catch (error) {
|
||||
logForDebugging(
|
||||
`Failed to check metrics opt-out status: ${errorMessage(error)}`,
|
||||
)
|
||||
logError(error)
|
||||
return { enabled: false, hasError: true }
|
||||
}
|
||||
}
|
||||
|
||||
// Create memoized version with custom error handling
|
||||
const memoizedCheckMetrics = memoizeWithTTLAsync(
|
||||
_checkMetricsEnabledAPI,
|
||||
CACHE_TTL_MS,
|
||||
)
|
||||
|
||||
/**
|
||||
* Fetch (in-memory memoized) and persist to disk on change.
|
||||
* Errors are not persisted — a transient failure should not overwrite a
|
||||
* known-good disk value.
|
||||
*/
|
||||
async function refreshMetricsStatus(): Promise<MetricsStatus> {
|
||||
const result = await memoizedCheckMetrics()
|
||||
if (result.hasError) {
|
||||
return result
|
||||
}
|
||||
|
||||
const cached = getGlobalConfig().metricsStatusCache
|
||||
const unchanged = cached !== undefined && cached.enabled === result.enabled
|
||||
// Skip write when unchanged AND timestamp still fresh — avoids config churn
|
||||
// when concurrent callers race past a stale disk entry and all try to write.
|
||||
if (unchanged && Date.now() - cached.timestamp < DISK_CACHE_TTL_MS) {
|
||||
return result
|
||||
}
|
||||
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
metricsStatusCache: {
|
||||
enabled: result.enabled,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
}))
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if metrics are enabled for the current organization.
|
||||
*
|
||||
* Two-tier cache:
|
||||
* - Disk (24h TTL): survives process restarts. Fresh disk cache → zero network.
|
||||
* - In-memory (1h TTL): dedupes the background refresh within a process.
|
||||
*
|
||||
* The caller (bigqueryExporter) tolerates stale reads — a missed export or
|
||||
* an extra one during the 24h window is acceptable.
|
||||
*/
|
||||
export async function checkMetricsEnabled(): Promise<MetricsStatus> {
|
||||
// Service key OAuth sessions lack user:profile scope → would 403.
|
||||
// API key users (non-subscribers) fall through and use x-api-key auth.
|
||||
// This check runs before the disk read so we never persist auth-state-derived
|
||||
// answers — only real API responses go to disk. Otherwise a service-key
|
||||
// session would poison the cache for a later full-OAuth session.
|
||||
if (isClaudeAISubscriber() && !hasProfileScope()) {
|
||||
return { enabled: false, hasError: false }
|
||||
}
|
||||
|
||||
const cached = getGlobalConfig().metricsStatusCache
|
||||
if (cached) {
|
||||
if (Date.now() - cached.timestamp > DISK_CACHE_TTL_MS) {
|
||||
// saveGlobalConfig's fallback path (config.ts:731) can throw if both
|
||||
// locked and fallback writes fail — catch here so fire-and-forget
|
||||
// doesn't become an unhandled rejection.
|
||||
void refreshMetricsStatus().catch(logError)
|
||||
}
|
||||
return {
|
||||
enabled: cached.enabled,
|
||||
hasError: false,
|
||||
}
|
||||
}
|
||||
|
||||
// First-ever run on this machine: block on the network to populate disk.
|
||||
return refreshMetricsStatus()
|
||||
}
|
||||
|
||||
// Export for testing purposes only
|
||||
export const _clearMetricsEnabledCacheForTesting = (): void => {
|
||||
memoizedCheckMetrics.cache.clear()
|
||||
}
|
||||
137
src/services/api/overageCreditGrant.ts
Normal file
137
src/services/api/overageCreditGrant.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
import axios from 'axios'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import { getOauthAccountInfo } from '../../utils/auth.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../../utils/config.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { isEssentialTrafficOnly } from '../../utils/privacyLevel.js'
|
||||
import { getOAuthHeaders, prepareApiRequest } from '../../utils/teleport/api.js'
|
||||
|
||||
export type OverageCreditGrantInfo = {
|
||||
available: boolean
|
||||
eligible: boolean
|
||||
granted: boolean
|
||||
amount_minor_units: number | null
|
||||
currency: string | null
|
||||
}
|
||||
|
||||
type CachedGrantEntry = {
|
||||
info: OverageCreditGrantInfo
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
const CACHE_TTL_MS = 60 * 60 * 1000 // 1 hour
|
||||
|
||||
/**
|
||||
* Fetch the current user's overage credit grant eligibility from the backend.
|
||||
* The backend resolves tier-specific amounts and role-based claim permission,
|
||||
* so the CLI just reads the response without replicating that logic.
|
||||
*/
|
||||
async function fetchOverageCreditGrant(): Promise<OverageCreditGrantInfo | null> {
|
||||
try {
|
||||
const { accessToken, orgUUID } = await prepareApiRequest()
|
||||
const url = `${getOauthConfig().BASE_API_URL}/api/oauth/organizations/${orgUUID}/overage_credit_grant`
|
||||
const response = await axios.get<OverageCreditGrantInfo>(url, {
|
||||
headers: getOAuthHeaders(accessToken),
|
||||
})
|
||||
return response.data
|
||||
} catch (err) {
|
||||
logError(err)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cached grant info. Returns null if no cache or cache is stale.
|
||||
* Callers should render nothing (not block) when this returns null —
|
||||
* refreshOverageCreditGrantCache fires lazily to populate it.
|
||||
*/
|
||||
export function getCachedOverageCreditGrant(): OverageCreditGrantInfo | null {
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
if (!orgId) return null
|
||||
const cached = getGlobalConfig().overageCreditGrantCache?.[orgId]
|
||||
if (!cached) return null
|
||||
if (Date.now() - cached.timestamp > CACHE_TTL_MS) return null
|
||||
return cached.info
|
||||
}
|
||||
|
||||
/**
|
||||
* Drop the current org's cached entry so the next read refetches.
|
||||
* Leaves other orgs' entries intact.
|
||||
*/
|
||||
export function invalidateOverageCreditGrantCache(): void {
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
if (!orgId) return
|
||||
const cache = getGlobalConfig().overageCreditGrantCache
|
||||
if (!cache || !(orgId in cache)) return
|
||||
saveGlobalConfig(prev => {
|
||||
const next = { ...prev.overageCreditGrantCache }
|
||||
delete next[orgId]
|
||||
return { ...prev, overageCreditGrantCache: next }
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch and cache grant info. Fire-and-forget; call when an upsell surface
|
||||
* is about to render and the cache is empty.
|
||||
*/
|
||||
export async function refreshOverageCreditGrantCache(): Promise<void> {
|
||||
if (isEssentialTrafficOnly()) return
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
if (!orgId) return
|
||||
const info = await fetchOverageCreditGrant()
|
||||
if (!info) return
|
||||
// Skip rewriting info if grant data is unchanged — avoids config write
|
||||
// amplification (inc-4552 pattern). Still refresh the timestamp so the
|
||||
// TTL-based staleness check in getCachedOverageCreditGrant doesn't keep
|
||||
// re-triggering API calls on every component mount.
|
||||
saveGlobalConfig(prev => {
|
||||
// Derive from prev (lock-fresh) rather than a pre-lock getGlobalConfig()
|
||||
// read — saveConfigWithLock re-reads config from disk under the file lock,
|
||||
// so another CLI instance may have written between any outer read and lock
|
||||
// acquire.
|
||||
const prevCached = prev.overageCreditGrantCache?.[orgId]
|
||||
const existing = prevCached?.info
|
||||
const dataUnchanged =
|
||||
existing &&
|
||||
existing.available === info.available &&
|
||||
existing.eligible === info.eligible &&
|
||||
existing.granted === info.granted &&
|
||||
existing.amount_minor_units === info.amount_minor_units &&
|
||||
existing.currency === info.currency
|
||||
// When data is unchanged and timestamp is still fresh, skip the write entirely
|
||||
if (
|
||||
dataUnchanged &&
|
||||
prevCached &&
|
||||
Date.now() - prevCached.timestamp <= CACHE_TTL_MS
|
||||
) {
|
||||
return prev
|
||||
}
|
||||
const entry: CachedGrantEntry = {
|
||||
info: dataUnchanged ? existing : info,
|
||||
timestamp: Date.now(),
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
overageCreditGrantCache: {
|
||||
...prev.overageCreditGrantCache,
|
||||
[orgId]: entry,
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Format the grant amount for display. Returns null if amount isn't available
|
||||
* (not eligible, or currency we don't know how to format).
|
||||
*/
|
||||
export function formatGrantAmount(info: OverageCreditGrantInfo): string | null {
|
||||
if (info.amount_minor_units == null || !info.currency) return null
|
||||
// For now only USD; backend may expand later
|
||||
if (info.currency.toUpperCase() === 'USD') {
|
||||
const dollars = info.amount_minor_units / 100
|
||||
return Number.isInteger(dollars) ? `$${dollars}` : `$${dollars.toFixed(2)}`
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
export type { CachedGrantEntry as OverageCreditGrantCacheEntry }
|
||||
727
src/services/api/promptCacheBreakDetection.ts
Normal file
727
src/services/api/promptCacheBreakDetection.ts
Normal file
@@ -0,0 +1,727 @@
|
||||
import type { BetaToolUnion } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
|
||||
import type { TextBlockParam } from '@anthropic-ai/sdk/resources/index.mjs'
|
||||
import { createPatch } from 'diff'
|
||||
import { mkdir, writeFile } from 'fs/promises'
|
||||
import { join } from 'path'
|
||||
import type { AgentId } from 'src/types/ids.js'
|
||||
import type { Message } from 'src/types/message.js'
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
import { djb2Hash } from 'src/utils/hash.js'
|
||||
import { logError } from 'src/utils/log.js'
|
||||
import { getClaudeTempDir } from 'src/utils/permissions/filesystem.js'
|
||||
import { jsonStringify } from 'src/utils/slowOperations.js'
|
||||
import type { QuerySource } from '../../constants/querySource.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
|
||||
function getCacheBreakDiffPath(): string {
|
||||
const chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
|
||||
let suffix = ''
|
||||
for (let i = 0; i < 4; i++) {
|
||||
suffix += chars[Math.floor(Math.random() * chars.length)]
|
||||
}
|
||||
return join(getClaudeTempDir(), `cache-break-${suffix}.diff`)
|
||||
}
|
||||
|
||||
type PreviousState = {
|
||||
systemHash: number
|
||||
toolsHash: number
|
||||
/** Hash of system blocks WITH cache_control intact. Catches scope/TTL flips
|
||||
* (global↔org, 1h↔5m) that stripCacheControl erases from systemHash. */
|
||||
cacheControlHash: number
|
||||
toolNames: string[]
|
||||
/** Per-tool schema hash. Diffed to name which tool's description changed
|
||||
* when toolSchemasChanged but added=removed=0 (77% of tool breaks per
|
||||
* BQ 2026-03-22). AgentTool/SkillTool embed dynamic agent/command lists. */
|
||||
perToolHashes: Record<string, number>
|
||||
systemCharCount: number
|
||||
model: string
|
||||
fastMode: boolean
|
||||
/** 'tool_based' | 'system_prompt' | 'none' — flips when MCP tools are
|
||||
* discovered/removed. */
|
||||
globalCacheStrategy: string
|
||||
/** Sorted beta header list. Diffed to show which headers were added/removed. */
|
||||
betas: string[]
|
||||
/** AFK_MODE_BETA_HEADER presence — should NOT break cache anymore
|
||||
* (sticky-on latched in claude.ts). Tracked to verify the fix. */
|
||||
autoModeActive: boolean
|
||||
/** Overage state flip — should NOT break cache anymore (eligibility is
|
||||
* latched session-stable in should1hCacheTTL). Tracked to verify the fix. */
|
||||
isUsingOverage: boolean
|
||||
/** Cache-editing beta header presence — should NOT break cache anymore
|
||||
* (sticky-on latched in claude.ts). Tracked to verify the fix. */
|
||||
cachedMCEnabled: boolean
|
||||
/** Resolved effort (env → options → model default). Goes into output_config
|
||||
* or anthropic_internal.effort_override. */
|
||||
effortValue: string
|
||||
/** Hash of getExtraBodyParams() — catches CLAUDE_CODE_EXTRA_BODY and
|
||||
* anthropic_internal changes. */
|
||||
extraBodyHash: number
|
||||
callCount: number
|
||||
pendingChanges: PendingChanges | null
|
||||
prevCacheReadTokens: number | null
|
||||
/** Set when cached microcompact sends cache_edits deletions. Cache reads
|
||||
* will legitimately drop — this is expected, not a break. */
|
||||
cacheDeletionsPending: boolean
|
||||
buildDiffableContent: () => string
|
||||
}
|
||||
|
||||
type PendingChanges = {
|
||||
systemPromptChanged: boolean
|
||||
toolSchemasChanged: boolean
|
||||
modelChanged: boolean
|
||||
fastModeChanged: boolean
|
||||
cacheControlChanged: boolean
|
||||
globalCacheStrategyChanged: boolean
|
||||
betasChanged: boolean
|
||||
autoModeChanged: boolean
|
||||
overageChanged: boolean
|
||||
cachedMCChanged: boolean
|
||||
effortChanged: boolean
|
||||
extraBodyChanged: boolean
|
||||
addedToolCount: number
|
||||
removedToolCount: number
|
||||
systemCharDelta: number
|
||||
addedTools: string[]
|
||||
removedTools: string[]
|
||||
changedToolSchemas: string[]
|
||||
previousModel: string
|
||||
newModel: string
|
||||
prevGlobalCacheStrategy: string
|
||||
newGlobalCacheStrategy: string
|
||||
addedBetas: string[]
|
||||
removedBetas: string[]
|
||||
prevEffortValue: string
|
||||
newEffortValue: string
|
||||
buildPrevDiffableContent: () => string
|
||||
}
|
||||
|
||||
const previousStateBySource = new Map<string, PreviousState>()
|
||||
|
||||
// Cap the number of tracked sources to prevent unbounded memory growth.
|
||||
// Each entry stores a ~300KB+ diffableContent string (serialized system prompt
|
||||
// + tool schemas). Without a cap, spawning many subagents (each with a unique
|
||||
// agentId key) causes the map to grow indefinitely.
|
||||
const MAX_TRACKED_SOURCES = 10
|
||||
|
||||
const TRACKED_SOURCE_PREFIXES = [
|
||||
'repl_main_thread',
|
||||
'sdk',
|
||||
'agent:custom',
|
||||
'agent:default',
|
||||
'agent:builtin',
|
||||
]
|
||||
|
||||
// Minimum absolute token drop required to trigger a cache break warning.
|
||||
// Small drops (e.g., a few thousand tokens) can happen due to normal variation
|
||||
// and aren't worth alerting on.
|
||||
const MIN_CACHE_MISS_TOKENS = 2_000
|
||||
|
||||
// Anthropic's server-side prompt cache TTL thresholds to test.
|
||||
// Cache breaks after these durations are likely due to TTL expiration
|
||||
// rather than client-side changes.
|
||||
const CACHE_TTL_5MIN_MS = 5 * 60 * 1000
|
||||
export const CACHE_TTL_1HOUR_MS = 60 * 60 * 1000
|
||||
|
||||
// Models to exclude from cache break detection (e.g., haiku has different caching behavior)
|
||||
function isExcludedModel(model: string): boolean {
|
||||
return model.includes('haiku')
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the tracking key for a querySource, or null if untracked.
|
||||
* Compact shares the same server-side cache as repl_main_thread
|
||||
* (same cacheSafeParams), so they share tracking state.
|
||||
*
|
||||
* For subagents with a tracked querySource, uses the unique agentId to
|
||||
* isolate tracking state. This prevents false positive cache break
|
||||
* notifications when multiple instances of the same agent type run
|
||||
* concurrently.
|
||||
*
|
||||
* Untracked sources (speculation, session_memory, prompt_suggestion, etc.)
|
||||
* are short-lived forked agents where cache break detection provides no
|
||||
* value — they run 1-3 turns with a fresh agentId each time, so there's
|
||||
* nothing meaningful to compare against. Their cache metrics are still
|
||||
* logged via tengu_api_success for analytics.
|
||||
*/
|
||||
function getTrackingKey(
|
||||
querySource: QuerySource,
|
||||
agentId?: AgentId,
|
||||
): string | null {
|
||||
if (querySource === 'compact') return 'repl_main_thread'
|
||||
for (const prefix of TRACKED_SOURCE_PREFIXES) {
|
||||
if (querySource.startsWith(prefix)) return agentId || querySource
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function stripCacheControl(
|
||||
items: ReadonlyArray<Record<string, unknown>>,
|
||||
): unknown[] {
|
||||
return items.map(item => {
|
||||
if (!('cache_control' in item)) return item
|
||||
const { cache_control: _, ...rest } = item
|
||||
return rest
|
||||
})
|
||||
}
|
||||
|
||||
function computeHash(data: unknown): number {
|
||||
const str = jsonStringify(data)
|
||||
if (typeof Bun !== 'undefined') {
|
||||
const hash = Bun.hash(str)
|
||||
// Bun.hash can return bigint for large inputs; convert to number safely
|
||||
return typeof hash === 'bigint' ? Number(hash & 0xffffffffn) : hash
|
||||
}
|
||||
// Fallback for non-Bun runtimes (e.g. Node.js via npm global install)
|
||||
return djb2Hash(str)
|
||||
}
|
||||
|
||||
/** MCP tool names are user-controlled (server config) and may leak filepaths.
|
||||
* Collapse them to 'mcp'; built-in names are a fixed vocabulary. */
|
||||
function sanitizeToolName(name: string): string {
|
||||
return name.startsWith('mcp__') ? 'mcp' : name
|
||||
}
|
||||
|
||||
function computePerToolHashes(
|
||||
strippedTools: ReadonlyArray<unknown>,
|
||||
names: string[],
|
||||
): Record<string, number> {
|
||||
const hashes: Record<string, number> = {}
|
||||
for (let i = 0; i < strippedTools.length; i++) {
|
||||
hashes[names[i] ?? `__idx_${i}`] = computeHash(strippedTools[i])
|
||||
}
|
||||
return hashes
|
||||
}
|
||||
|
||||
function getSystemCharCount(system: TextBlockParam[]): number {
|
||||
let total = 0
|
||||
for (const block of system) {
|
||||
total += block.text.length
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
function buildDiffableContent(
|
||||
system: TextBlockParam[],
|
||||
tools: BetaToolUnion[],
|
||||
model: string,
|
||||
): string {
|
||||
const systemText = system.map(b => b.text).join('\n\n')
|
||||
const toolDetails = tools
|
||||
.map(t => {
|
||||
if (!('name' in t)) return 'unknown'
|
||||
const desc = 'description' in t ? t.description : ''
|
||||
const schema = 'input_schema' in t ? jsonStringify(t.input_schema) : ''
|
||||
return `${t.name}\n description: ${desc}\n input_schema: ${schema}`
|
||||
})
|
||||
.sort()
|
||||
.join('\n\n')
|
||||
return `Model: ${model}\n\n=== System Prompt ===\n\n${systemText}\n\n=== Tools (${tools.length}) ===\n\n${toolDetails}\n`
|
||||
}
|
||||
|
||||
/** Extended tracking snapshot — everything that could affect the server-side
|
||||
* cache key that we can observe from the client. All fields are optional so
|
||||
* the call site can add incrementally; undefined fields compare as stable. */
|
||||
export type PromptStateSnapshot = {
|
||||
system: TextBlockParam[]
|
||||
toolSchemas: BetaToolUnion[]
|
||||
querySource: QuerySource
|
||||
model: string
|
||||
agentId?: AgentId
|
||||
fastMode?: boolean
|
||||
globalCacheStrategy?: string
|
||||
betas?: readonly string[]
|
||||
autoModeActive?: boolean
|
||||
isUsingOverage?: boolean
|
||||
cachedMCEnabled?: boolean
|
||||
effortValue?: string | number
|
||||
extraBodyParams?: unknown
|
||||
}
|
||||
|
||||
/**
|
||||
* Phase 1 (pre-call): Record the current prompt/tool state and detect what changed.
|
||||
* Does NOT fire events — just stores pending changes for phase 2 to use.
|
||||
*/
|
||||
export function recordPromptState(snapshot: PromptStateSnapshot): void {
|
||||
try {
|
||||
const {
|
||||
system,
|
||||
toolSchemas,
|
||||
querySource,
|
||||
model,
|
||||
agentId,
|
||||
fastMode,
|
||||
globalCacheStrategy = '',
|
||||
betas = [],
|
||||
autoModeActive = false,
|
||||
isUsingOverage = false,
|
||||
cachedMCEnabled = false,
|
||||
effortValue,
|
||||
extraBodyParams,
|
||||
} = snapshot
|
||||
const key = getTrackingKey(querySource, agentId)
|
||||
if (!key) return
|
||||
|
||||
const strippedSystem = stripCacheControl(
|
||||
system as unknown as ReadonlyArray<Record<string, unknown>>,
|
||||
)
|
||||
const strippedTools = stripCacheControl(
|
||||
toolSchemas as unknown as ReadonlyArray<Record<string, unknown>>,
|
||||
)
|
||||
|
||||
const systemHash = computeHash(strippedSystem)
|
||||
const toolsHash = computeHash(strippedTools)
|
||||
// Hash the full system array INCLUDING cache_control — this catches
|
||||
// scope flips (global↔org/none) and TTL flips (1h↔5m) that the stripped
|
||||
// hash can't see because the text content is identical.
|
||||
const cacheControlHash = computeHash(
|
||||
system.map(b => ('cache_control' in b ? b.cache_control : null)),
|
||||
)
|
||||
const toolNames = toolSchemas.map(t => ('name' in t ? t.name : 'unknown'))
|
||||
// Only compute per-tool hashes when the aggregate changed — common case
|
||||
// (tools unchanged) skips N extra jsonStringify calls.
|
||||
const computeToolHashes = () =>
|
||||
computePerToolHashes(strippedTools, toolNames)
|
||||
const systemCharCount = getSystemCharCount(system)
|
||||
const lazyDiffableContent = () =>
|
||||
buildDiffableContent(system, toolSchemas, model)
|
||||
const isFastMode = fastMode ?? false
|
||||
const sortedBetas = [...betas].sort()
|
||||
const effortStr = effortValue === undefined ? '' : String(effortValue)
|
||||
const extraBodyHash =
|
||||
extraBodyParams === undefined ? 0 : computeHash(extraBodyParams)
|
||||
|
||||
const prev = previousStateBySource.get(key)
|
||||
|
||||
if (!prev) {
|
||||
// Evict oldest entries if map is at capacity
|
||||
while (previousStateBySource.size >= MAX_TRACKED_SOURCES) {
|
||||
const oldest = previousStateBySource.keys().next().value
|
||||
if (oldest !== undefined) previousStateBySource.delete(oldest)
|
||||
}
|
||||
|
||||
previousStateBySource.set(key, {
|
||||
systemHash,
|
||||
toolsHash,
|
||||
cacheControlHash,
|
||||
toolNames,
|
||||
systemCharCount,
|
||||
model,
|
||||
fastMode: isFastMode,
|
||||
globalCacheStrategy,
|
||||
betas: sortedBetas,
|
||||
autoModeActive,
|
||||
isUsingOverage,
|
||||
cachedMCEnabled,
|
||||
effortValue: effortStr,
|
||||
extraBodyHash,
|
||||
callCount: 1,
|
||||
pendingChanges: null,
|
||||
prevCacheReadTokens: null,
|
||||
cacheDeletionsPending: false,
|
||||
buildDiffableContent: lazyDiffableContent,
|
||||
perToolHashes: computeToolHashes(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
prev.callCount++
|
||||
|
||||
const systemPromptChanged = systemHash !== prev.systemHash
|
||||
const toolSchemasChanged = toolsHash !== prev.toolsHash
|
||||
const modelChanged = model !== prev.model
|
||||
const fastModeChanged = isFastMode !== prev.fastMode
|
||||
const cacheControlChanged = cacheControlHash !== prev.cacheControlHash
|
||||
const globalCacheStrategyChanged =
|
||||
globalCacheStrategy !== prev.globalCacheStrategy
|
||||
const betasChanged =
|
||||
sortedBetas.length !== prev.betas.length ||
|
||||
sortedBetas.some((b, i) => b !== prev.betas[i])
|
||||
const autoModeChanged = autoModeActive !== prev.autoModeActive
|
||||
const overageChanged = isUsingOverage !== prev.isUsingOverage
|
||||
const cachedMCChanged = cachedMCEnabled !== prev.cachedMCEnabled
|
||||
const effortChanged = effortStr !== prev.effortValue
|
||||
const extraBodyChanged = extraBodyHash !== prev.extraBodyHash
|
||||
|
||||
if (
|
||||
systemPromptChanged ||
|
||||
toolSchemasChanged ||
|
||||
modelChanged ||
|
||||
fastModeChanged ||
|
||||
cacheControlChanged ||
|
||||
globalCacheStrategyChanged ||
|
||||
betasChanged ||
|
||||
autoModeChanged ||
|
||||
overageChanged ||
|
||||
cachedMCChanged ||
|
||||
effortChanged ||
|
||||
extraBodyChanged
|
||||
) {
|
||||
const prevToolSet = new Set(prev.toolNames)
|
||||
const newToolSet = new Set(toolNames)
|
||||
const prevBetaSet = new Set(prev.betas)
|
||||
const newBetaSet = new Set(sortedBetas)
|
||||
const addedTools = toolNames.filter(n => !prevToolSet.has(n))
|
||||
const removedTools = prev.toolNames.filter(n => !newToolSet.has(n))
|
||||
const changedToolSchemas: string[] = []
|
||||
if (toolSchemasChanged) {
|
||||
const newHashes = computeToolHashes()
|
||||
for (const name of toolNames) {
|
||||
if (!prevToolSet.has(name)) continue
|
||||
if (newHashes[name] !== prev.perToolHashes[name]) {
|
||||
changedToolSchemas.push(name)
|
||||
}
|
||||
}
|
||||
prev.perToolHashes = newHashes
|
||||
}
|
||||
prev.pendingChanges = {
|
||||
systemPromptChanged,
|
||||
toolSchemasChanged,
|
||||
modelChanged,
|
||||
fastModeChanged,
|
||||
cacheControlChanged,
|
||||
globalCacheStrategyChanged,
|
||||
betasChanged,
|
||||
autoModeChanged,
|
||||
overageChanged,
|
||||
cachedMCChanged,
|
||||
effortChanged,
|
||||
extraBodyChanged,
|
||||
addedToolCount: addedTools.length,
|
||||
removedToolCount: removedTools.length,
|
||||
addedTools,
|
||||
removedTools,
|
||||
changedToolSchemas,
|
||||
systemCharDelta: systemCharCount - prev.systemCharCount,
|
||||
previousModel: prev.model,
|
||||
newModel: model,
|
||||
prevGlobalCacheStrategy: prev.globalCacheStrategy,
|
||||
newGlobalCacheStrategy: globalCacheStrategy,
|
||||
addedBetas: sortedBetas.filter(b => !prevBetaSet.has(b)),
|
||||
removedBetas: prev.betas.filter(b => !newBetaSet.has(b)),
|
||||
prevEffortValue: prev.effortValue,
|
||||
newEffortValue: effortStr,
|
||||
buildPrevDiffableContent: prev.buildDiffableContent,
|
||||
}
|
||||
} else {
|
||||
prev.pendingChanges = null
|
||||
}
|
||||
|
||||
prev.systemHash = systemHash
|
||||
prev.toolsHash = toolsHash
|
||||
prev.cacheControlHash = cacheControlHash
|
||||
prev.toolNames = toolNames
|
||||
prev.systemCharCount = systemCharCount
|
||||
prev.model = model
|
||||
prev.fastMode = isFastMode
|
||||
prev.globalCacheStrategy = globalCacheStrategy
|
||||
prev.betas = sortedBetas
|
||||
prev.autoModeActive = autoModeActive
|
||||
prev.isUsingOverage = isUsingOverage
|
||||
prev.cachedMCEnabled = cachedMCEnabled
|
||||
prev.effortValue = effortStr
|
||||
prev.extraBodyHash = extraBodyHash
|
||||
prev.buildDiffableContent = lazyDiffableContent
|
||||
} catch (e: unknown) {
|
||||
logError(e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Phase 2 (post-call): Check the API response's cache tokens to determine
|
||||
* if a cache break actually occurred. If it did, use the pending changes
|
||||
* from phase 1 to explain why.
|
||||
*/
|
||||
export async function checkResponseForCacheBreak(
|
||||
querySource: QuerySource,
|
||||
cacheReadTokens: number,
|
||||
cacheCreationTokens: number,
|
||||
messages: Message[],
|
||||
agentId?: AgentId,
|
||||
requestId?: string | null,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const key = getTrackingKey(querySource, agentId)
|
||||
if (!key) return
|
||||
|
||||
const state = previousStateBySource.get(key)
|
||||
if (!state) return
|
||||
|
||||
// Skip excluded models (e.g., haiku has different caching behavior)
|
||||
if (isExcludedModel(state.model)) return
|
||||
|
||||
const prevCacheRead = state.prevCacheReadTokens
|
||||
state.prevCacheReadTokens = cacheReadTokens
|
||||
|
||||
// Calculate time since last call for TTL detection by finding the most recent
|
||||
// assistant message timestamp in the messages array (before the current response)
|
||||
const lastAssistantMessage = messages.findLast(m => m.type === 'assistant')
|
||||
const timeSinceLastAssistantMsg = lastAssistantMessage
|
||||
? Date.now() - new Date(lastAssistantMessage.timestamp).getTime()
|
||||
: null
|
||||
|
||||
// Skip the first call — no previous value to compare against
|
||||
if (prevCacheRead === null) return
|
||||
|
||||
const changes = state.pendingChanges
|
||||
|
||||
// Cache deletions via cached microcompact intentionally reduce the cached
|
||||
// prefix. The drop in cache read tokens is expected — reset the baseline
|
||||
// so we don't false-positive on the next call.
|
||||
if (state.cacheDeletionsPending) {
|
||||
state.cacheDeletionsPending = false
|
||||
logForDebugging(
|
||||
`[PROMPT CACHE] cache deletion applied, cache read: ${prevCacheRead} → ${cacheReadTokens} (expected drop)`,
|
||||
)
|
||||
// Don't flag as a break — the remaining state is still valid
|
||||
state.pendingChanges = null
|
||||
return
|
||||
}
|
||||
|
||||
// Detect a cache break: cache read dropped >5% from previous AND
|
||||
// the absolute drop exceeds the minimum threshold.
|
||||
const tokenDrop = prevCacheRead - cacheReadTokens
|
||||
if (
|
||||
cacheReadTokens >= prevCacheRead * 0.95 ||
|
||||
tokenDrop < MIN_CACHE_MISS_TOKENS
|
||||
) {
|
||||
state.pendingChanges = null
|
||||
return
|
||||
}
|
||||
|
||||
// Build explanation from pending changes (if any)
|
||||
const parts: string[] = []
|
||||
if (changes) {
|
||||
if (changes.modelChanged) {
|
||||
parts.push(
|
||||
`model changed (${changes.previousModel} → ${changes.newModel})`,
|
||||
)
|
||||
}
|
||||
if (changes.systemPromptChanged) {
|
||||
const charDelta = changes.systemCharDelta
|
||||
const charInfo =
|
||||
charDelta === 0
|
||||
? ''
|
||||
: charDelta > 0
|
||||
? ` (+${charDelta} chars)`
|
||||
: ` (${charDelta} chars)`
|
||||
parts.push(`system prompt changed${charInfo}`)
|
||||
}
|
||||
if (changes.toolSchemasChanged) {
|
||||
const toolDiff =
|
||||
changes.addedToolCount > 0 || changes.removedToolCount > 0
|
||||
? ` (+${changes.addedToolCount}/-${changes.removedToolCount} tools)`
|
||||
: ' (tool prompt/schema changed, same tool set)'
|
||||
parts.push(`tools changed${toolDiff}`)
|
||||
}
|
||||
if (changes.fastModeChanged) {
|
||||
parts.push('fast mode toggled')
|
||||
}
|
||||
if (changes.globalCacheStrategyChanged) {
|
||||
parts.push(
|
||||
`global cache strategy changed (${changes.prevGlobalCacheStrategy || 'none'} → ${changes.newGlobalCacheStrategy || 'none'})`,
|
||||
)
|
||||
}
|
||||
if (
|
||||
changes.cacheControlChanged &&
|
||||
!changes.globalCacheStrategyChanged &&
|
||||
!changes.systemPromptChanged
|
||||
) {
|
||||
// Only report as standalone cause if nothing else explains it —
|
||||
// otherwise the scope/TTL flip is a consequence, not the root cause.
|
||||
parts.push('cache_control changed (scope or TTL)')
|
||||
}
|
||||
if (changes.betasChanged) {
|
||||
const added = changes.addedBetas.length
|
||||
? `+${changes.addedBetas.join(',')}`
|
||||
: ''
|
||||
const removed = changes.removedBetas.length
|
||||
? `-${changes.removedBetas.join(',')}`
|
||||
: ''
|
||||
const diff = [added, removed].filter(Boolean).join(' ')
|
||||
parts.push(`betas changed${diff ? ` (${diff})` : ''}`)
|
||||
}
|
||||
if (changes.autoModeChanged) {
|
||||
parts.push('auto mode toggled')
|
||||
}
|
||||
if (changes.overageChanged) {
|
||||
parts.push('overage state changed (TTL latched, no flip)')
|
||||
}
|
||||
if (changes.cachedMCChanged) {
|
||||
parts.push('cached microcompact toggled')
|
||||
}
|
||||
if (changes.effortChanged) {
|
||||
parts.push(
|
||||
`effort changed (${changes.prevEffortValue || 'default'} → ${changes.newEffortValue || 'default'})`,
|
||||
)
|
||||
}
|
||||
if (changes.extraBodyChanged) {
|
||||
parts.push('extra body params changed')
|
||||
}
|
||||
}
|
||||
|
||||
// Check if time gap suggests TTL expiration
|
||||
const lastAssistantMsgOver5minAgo =
|
||||
timeSinceLastAssistantMsg !== null &&
|
||||
timeSinceLastAssistantMsg > CACHE_TTL_5MIN_MS
|
||||
const lastAssistantMsgOver1hAgo =
|
||||
timeSinceLastAssistantMsg !== null &&
|
||||
timeSinceLastAssistantMsg > CACHE_TTL_1HOUR_MS
|
||||
|
||||
// Post PR #19823 BQ analysis (bq-queries/prompt-caching/cache_break_pr19823_analysis.sql):
|
||||
// when all client-side flags are false and the gap is under TTL, ~90% of breaks
|
||||
// are server-side routing/eviction or billed/inference disagreement. Label
|
||||
// accordingly instead of implying a CC bug hunt.
|
||||
let reason: string
|
||||
if (parts.length > 0) {
|
||||
reason = parts.join(', ')
|
||||
} else if (lastAssistantMsgOver1hAgo) {
|
||||
reason = 'possible 1h TTL expiry (prompt unchanged)'
|
||||
} else if (lastAssistantMsgOver5minAgo) {
|
||||
reason = 'possible 5min TTL expiry (prompt unchanged)'
|
||||
} else if (timeSinceLastAssistantMsg !== null) {
|
||||
reason = 'likely server-side (prompt unchanged, <5min gap)'
|
||||
} else {
|
||||
reason = 'unknown cause'
|
||||
}
|
||||
|
||||
logEvent('tengu_prompt_cache_break', {
|
||||
systemPromptChanged: changes?.systemPromptChanged ?? false,
|
||||
toolSchemasChanged: changes?.toolSchemasChanged ?? false,
|
||||
modelChanged: changes?.modelChanged ?? false,
|
||||
fastModeChanged: changes?.fastModeChanged ?? false,
|
||||
cacheControlChanged: changes?.cacheControlChanged ?? false,
|
||||
globalCacheStrategyChanged: changes?.globalCacheStrategyChanged ?? false,
|
||||
betasChanged: changes?.betasChanged ?? false,
|
||||
autoModeChanged: changes?.autoModeChanged ?? false,
|
||||
overageChanged: changes?.overageChanged ?? false,
|
||||
cachedMCChanged: changes?.cachedMCChanged ?? false,
|
||||
effortChanged: changes?.effortChanged ?? false,
|
||||
extraBodyChanged: changes?.extraBodyChanged ?? false,
|
||||
addedToolCount: changes?.addedToolCount ?? 0,
|
||||
removedToolCount: changes?.removedToolCount ?? 0,
|
||||
systemCharDelta: changes?.systemCharDelta ?? 0,
|
||||
// Tool names are sanitized: built-in names are a fixed vocabulary,
|
||||
// MCP tools collapse to 'mcp' (user-configured, could leak paths).
|
||||
addedTools: (changes?.addedTools ?? [])
|
||||
.map(sanitizeToolName)
|
||||
.join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
removedTools: (changes?.removedTools ?? [])
|
||||
.map(sanitizeToolName)
|
||||
.join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
changedToolSchemas: (changes?.changedToolSchemas ?? [])
|
||||
.map(sanitizeToolName)
|
||||
.join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
// Beta header names and cache strategy are fixed enum-like values,
|
||||
// not code or filepaths. requestId is an opaque server-generated ID.
|
||||
addedBetas: (changes?.addedBetas ?? []).join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
removedBetas: (changes?.removedBetas ?? []).join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
prevGlobalCacheStrategy: (changes?.prevGlobalCacheStrategy ??
|
||||
'') as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
newGlobalCacheStrategy: (changes?.newGlobalCacheStrategy ??
|
||||
'') as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
callNumber: state.callCount,
|
||||
prevCacheReadTokens: prevCacheRead,
|
||||
cacheReadTokens,
|
||||
cacheCreationTokens,
|
||||
timeSinceLastAssistantMsg: timeSinceLastAssistantMsg ?? -1,
|
||||
lastAssistantMsgOver5minAgo,
|
||||
lastAssistantMsgOver1hAgo,
|
||||
requestId: (requestId ??
|
||||
'') as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
|
||||
// Write diff file for ant debugging via --debug. The path is included in
|
||||
// the summary log so ants can find it (DevBar UI removed — event data
|
||||
// flows reliably to BQ for analytics).
|
||||
let diffPath: string | undefined
|
||||
if (changes?.buildPrevDiffableContent) {
|
||||
diffPath = await writeCacheBreakDiff(
|
||||
changes.buildPrevDiffableContent(),
|
||||
state.buildDiffableContent(),
|
||||
)
|
||||
}
|
||||
|
||||
const diffSuffix = diffPath ? `, diff: ${diffPath}` : ''
|
||||
const summary = `[PROMPT CACHE BREAK] ${reason} [source=${querySource}, call #${state.callCount}, cache read: ${prevCacheRead} → ${cacheReadTokens}, creation: ${cacheCreationTokens}${diffSuffix}]`
|
||||
|
||||
logForDebugging(summary, { level: 'warn' })
|
||||
|
||||
state.pendingChanges = null
|
||||
} catch (e: unknown) {
|
||||
logError(e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Call when cached microcompact sends cache_edits deletions.
|
||||
* The next API response will have lower cache read tokens — that's
|
||||
* expected, not a cache break.
|
||||
*/
|
||||
export function notifyCacheDeletion(
|
||||
querySource: QuerySource,
|
||||
agentId?: AgentId,
|
||||
): void {
|
||||
const key = getTrackingKey(querySource, agentId)
|
||||
const state = key ? previousStateBySource.get(key) : undefined
|
||||
if (state) {
|
||||
state.cacheDeletionsPending = true
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Call after compaction to reset the cache read baseline.
|
||||
* Compaction legitimately reduces message count, so cache read tokens
|
||||
* will naturally drop on the next call — that's not a break.
|
||||
*/
|
||||
export function notifyCompaction(
|
||||
querySource: QuerySource,
|
||||
agentId?: AgentId,
|
||||
): void {
|
||||
const key = getTrackingKey(querySource, agentId)
|
||||
const state = key ? previousStateBySource.get(key) : undefined
|
||||
if (state) {
|
||||
state.prevCacheReadTokens = null
|
||||
}
|
||||
}
|
||||
|
||||
export function cleanupAgentTracking(agentId: AgentId): void {
|
||||
previousStateBySource.delete(agentId)
|
||||
}
|
||||
|
||||
export function resetPromptCacheBreakDetection(): void {
|
||||
previousStateBySource.clear()
|
||||
}
|
||||
|
||||
async function writeCacheBreakDiff(
|
||||
prevContent: string,
|
||||
newContent: string,
|
||||
): Promise<string | undefined> {
|
||||
try {
|
||||
const diffPath = getCacheBreakDiffPath()
|
||||
await mkdir(getClaudeTempDir(), { recursive: true })
|
||||
const patch = createPatch(
|
||||
'prompt-state',
|
||||
prevContent,
|
||||
newContent,
|
||||
'before',
|
||||
'after',
|
||||
)
|
||||
await writeFile(diffPath, patch)
|
||||
return diffPath
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
281
src/services/api/referral.ts
Normal file
281
src/services/api/referral.ts
Normal file
@@ -0,0 +1,281 @@
|
||||
import axios from 'axios'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import {
|
||||
getOauthAccountInfo,
|
||||
getSubscriptionType,
|
||||
isClaudeAISubscriber,
|
||||
} from '../../utils/auth.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../../utils/config.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { isEssentialTrafficOnly } from '../../utils/privacyLevel.js'
|
||||
import { getOAuthHeaders, prepareApiRequest } from '../../utils/teleport/api.js'
|
||||
import type {
|
||||
ReferralCampaign,
|
||||
ReferralEligibilityResponse,
|
||||
ReferralRedemptionsResponse,
|
||||
ReferrerRewardInfo,
|
||||
} from '../oauth/types.js'
|
||||
|
||||
// Cache expiration time: 24 hours (eligibility changes only on subscription/experiment changes)
|
||||
const CACHE_EXPIRATION_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
// Track in-flight fetch to prevent duplicate API calls
|
||||
let fetchInProgress: Promise<ReferralEligibilityResponse | null> | null = null
|
||||
|
||||
export async function fetchReferralEligibility(
|
||||
campaign: ReferralCampaign = 'claude_code_guest_pass',
|
||||
): Promise<ReferralEligibilityResponse> {
|
||||
const { accessToken, orgUUID } = await prepareApiRequest()
|
||||
|
||||
const headers = {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
}
|
||||
|
||||
const url = `${getOauthConfig().BASE_API_URL}/api/oauth/organizations/${orgUUID}/referral/eligibility`
|
||||
|
||||
const response = await axios.get(url, {
|
||||
headers,
|
||||
params: { campaign },
|
||||
timeout: 5000, // 5 second timeout for background fetch
|
||||
})
|
||||
|
||||
return response.data
|
||||
}
|
||||
|
||||
export async function fetchReferralRedemptions(
|
||||
campaign: string = 'claude_code_guest_pass',
|
||||
): Promise<ReferralRedemptionsResponse> {
|
||||
const { accessToken, orgUUID } = await prepareApiRequest()
|
||||
|
||||
const headers = {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
}
|
||||
|
||||
const url = `${getOauthConfig().BASE_API_URL}/api/oauth/organizations/${orgUUID}/referral/redemptions`
|
||||
|
||||
const response = await axios.get<ReferralRedemptionsResponse>(url, {
|
||||
headers,
|
||||
params: { campaign },
|
||||
timeout: 10000, // 10 second timeout
|
||||
})
|
||||
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Prechecks for if user can access guest passes feature
|
||||
*/
|
||||
function shouldCheckForPasses(): boolean {
|
||||
return !!(
|
||||
getOauthAccountInfo()?.organizationUuid &&
|
||||
isClaudeAISubscriber() &&
|
||||
getSubscriptionType() === 'max'
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check cached passes eligibility from GlobalConfig
|
||||
* Returns current cached state and cache status
|
||||
*/
|
||||
export function checkCachedPassesEligibility(): {
|
||||
eligible: boolean
|
||||
needsRefresh: boolean
|
||||
hasCache: boolean
|
||||
} {
|
||||
if (!shouldCheckForPasses()) {
|
||||
return {
|
||||
eligible: false,
|
||||
needsRefresh: false,
|
||||
hasCache: false,
|
||||
}
|
||||
}
|
||||
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
if (!orgId) {
|
||||
return {
|
||||
eligible: false,
|
||||
needsRefresh: false,
|
||||
hasCache: false,
|
||||
}
|
||||
}
|
||||
|
||||
const config = getGlobalConfig()
|
||||
const cachedEntry = config.passesEligibilityCache?.[orgId]
|
||||
|
||||
if (!cachedEntry) {
|
||||
// No cached entry, needs fetch
|
||||
return {
|
||||
eligible: false,
|
||||
needsRefresh: true,
|
||||
hasCache: false,
|
||||
}
|
||||
}
|
||||
|
||||
const { eligible, timestamp } = cachedEntry
|
||||
const now = Date.now()
|
||||
const needsRefresh = now - timestamp > CACHE_EXPIRATION_MS
|
||||
|
||||
return {
|
||||
eligible,
|
||||
needsRefresh,
|
||||
hasCache: true,
|
||||
}
|
||||
}
|
||||
|
||||
const CURRENCY_SYMBOLS: Record<string, string> = {
|
||||
USD: '$',
|
||||
EUR: '€',
|
||||
GBP: '£',
|
||||
BRL: 'R$',
|
||||
CAD: 'CA$',
|
||||
AUD: 'A$',
|
||||
NZD: 'NZ$',
|
||||
SGD: 'S$',
|
||||
}
|
||||
|
||||
export function formatCreditAmount(reward: ReferrerRewardInfo): string {
|
||||
const symbol = CURRENCY_SYMBOLS[reward.currency] ?? `${reward.currency} `
|
||||
const amount = reward.amount_minor_units / 100
|
||||
const formatted = amount % 1 === 0 ? amount.toString() : amount.toFixed(2)
|
||||
return `${symbol}${formatted}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cached referrer reward info from eligibility cache
|
||||
* Returns the reward info if the user is in a v1 campaign, null otherwise
|
||||
*/
|
||||
export function getCachedReferrerReward(): ReferrerRewardInfo | null {
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
if (!orgId) return null
|
||||
const config = getGlobalConfig()
|
||||
const cachedEntry = config.passesEligibilityCache?.[orgId]
|
||||
return cachedEntry?.referrer_reward ?? null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cached remaining passes count from eligibility cache
|
||||
* Returns the number of remaining passes, or null if not available
|
||||
*/
|
||||
export function getCachedRemainingPasses(): number | null {
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
if (!orgId) return null
|
||||
const config = getGlobalConfig()
|
||||
const cachedEntry = config.passesEligibilityCache?.[orgId]
|
||||
return cachedEntry?.remaining_passes ?? null
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch passes eligibility and store in GlobalConfig
|
||||
* Returns the fetched response or null on error
|
||||
*/
|
||||
export async function fetchAndStorePassesEligibility(): Promise<ReferralEligibilityResponse | null> {
|
||||
// Return existing promise if fetch is already in progress
|
||||
if (fetchInProgress) {
|
||||
logForDebugging('Passes: Reusing in-flight eligibility fetch')
|
||||
return fetchInProgress
|
||||
}
|
||||
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
|
||||
if (!orgId) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Store the promise to share with concurrent calls
|
||||
fetchInProgress = (async () => {
|
||||
try {
|
||||
const response = await fetchReferralEligibility()
|
||||
|
||||
const cacheEntry = {
|
||||
...response,
|
||||
timestamp: Date.now(),
|
||||
}
|
||||
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
passesEligibilityCache: {
|
||||
...current.passesEligibilityCache,
|
||||
[orgId]: cacheEntry,
|
||||
},
|
||||
}))
|
||||
|
||||
logForDebugging(
|
||||
`Passes eligibility cached for org ${orgId}: ${response.eligible}`,
|
||||
)
|
||||
|
||||
return response
|
||||
} catch (error) {
|
||||
logForDebugging('Failed to fetch and cache passes eligibility')
|
||||
logError(error as Error)
|
||||
return null
|
||||
} finally {
|
||||
// Clear the promise when done
|
||||
fetchInProgress = null
|
||||
}
|
||||
})()
|
||||
|
||||
return fetchInProgress
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cached passes eligibility data or fetch if needed
|
||||
* Main entry point for all eligibility checks
|
||||
*
|
||||
* This function never blocks on network - it returns cached data immediately
|
||||
* and fetches in the background if needed. On cold start (no cache), it returns
|
||||
* null and the passes command won't be available until the next session.
|
||||
*/
|
||||
export async function getCachedOrFetchPassesEligibility(): Promise<ReferralEligibilityResponse | null> {
|
||||
if (!shouldCheckForPasses()) {
|
||||
return null
|
||||
}
|
||||
|
||||
const orgId = getOauthAccountInfo()?.organizationUuid
|
||||
if (!orgId) {
|
||||
return null
|
||||
}
|
||||
|
||||
const config = getGlobalConfig()
|
||||
const cachedEntry = config.passesEligibilityCache?.[orgId]
|
||||
const now = Date.now()
|
||||
|
||||
// No cache - trigger background fetch and return null (non-blocking)
|
||||
// The passes command won't be available this session, but will be next time
|
||||
if (!cachedEntry) {
|
||||
logForDebugging(
|
||||
'Passes: No cache, fetching eligibility in background (command unavailable this session)',
|
||||
)
|
||||
void fetchAndStorePassesEligibility()
|
||||
return null
|
||||
}
|
||||
|
||||
// Cache exists but is stale - return stale cache and trigger background refresh
|
||||
if (now - cachedEntry.timestamp > CACHE_EXPIRATION_MS) {
|
||||
logForDebugging(
|
||||
'Passes: Cache stale, returning cached data and refreshing in background',
|
||||
)
|
||||
void fetchAndStorePassesEligibility() // Background refresh
|
||||
const { timestamp, ...response } = cachedEntry
|
||||
return response as ReferralEligibilityResponse
|
||||
}
|
||||
|
||||
// Cache is fresh - return it immediately
|
||||
logForDebugging('Passes: Using fresh cached eligibility data')
|
||||
const { timestamp, ...response } = cachedEntry
|
||||
return response as ReferralEligibilityResponse
|
||||
}
|
||||
|
||||
/**
|
||||
* Prefetch passes eligibility on startup
|
||||
*/
|
||||
export async function prefetchPassesEligibility(): Promise<void> {
|
||||
// Skip network requests if nonessential traffic is disabled
|
||||
if (isEssentialTrafficOnly()) {
|
||||
return
|
||||
}
|
||||
|
||||
void getCachedOrFetchPassesEligibility()
|
||||
}
|
||||
514
src/services/api/sessionIngress.ts
Normal file
514
src/services/api/sessionIngress.ts
Normal file
@@ -0,0 +1,514 @@
|
||||
import axios, { type AxiosError } from 'axios'
|
||||
import type { UUID } from 'crypto'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import type { Entry, TranscriptMessage } from '../../types/logs.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { logForDiagnosticsNoPII } from '../../utils/diagLogs.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { sequential } from '../../utils/sequential.js'
|
||||
import { getSessionIngressAuthToken } from '../../utils/sessionIngressAuth.js'
|
||||
import { sleep } from '../../utils/sleep.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import { getOAuthHeaders } from '../../utils/teleport/api.js'
|
||||
|
||||
interface SessionIngressError {
|
||||
error?: {
|
||||
message?: string
|
||||
type?: string
|
||||
}
|
||||
}
|
||||
|
||||
// Module-level state
|
||||
const lastUuidMap: Map<string, UUID> = new Map()
|
||||
|
||||
const MAX_RETRIES = 10
|
||||
const BASE_DELAY_MS = 500
|
||||
|
||||
// Per-session sequential wrappers to prevent concurrent log writes
|
||||
const sequentialAppendBySession: Map<
|
||||
string,
|
||||
(
|
||||
entry: TranscriptMessage,
|
||||
url: string,
|
||||
headers: Record<string, string>,
|
||||
) => Promise<boolean>
|
||||
> = new Map()
|
||||
|
||||
/**
|
||||
* Gets or creates a sequential wrapper for a session
|
||||
* This ensures that log appends for a session are processed one at a time
|
||||
*/
|
||||
function getOrCreateSequentialAppend(sessionId: string) {
|
||||
let sequentialAppend = sequentialAppendBySession.get(sessionId)
|
||||
if (!sequentialAppend) {
|
||||
sequentialAppend = sequential(
|
||||
async (
|
||||
entry: TranscriptMessage,
|
||||
url: string,
|
||||
headers: Record<string, string>,
|
||||
) => await appendSessionLogImpl(sessionId, entry, url, headers),
|
||||
)
|
||||
sequentialAppendBySession.set(sessionId, sequentialAppend)
|
||||
}
|
||||
return sequentialAppend
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal implementation of appendSessionLog with retry logic
|
||||
* Retries on transient errors (network, 5xx, 429). On 409, adopts the server's
|
||||
* last UUID and retries (handles stale state from killed process's in-flight
|
||||
* requests). Fails immediately on 401.
|
||||
*/
|
||||
async function appendSessionLogImpl(
|
||||
sessionId: string,
|
||||
entry: TranscriptMessage,
|
||||
url: string,
|
||||
headers: Record<string, string>,
|
||||
): Promise<boolean> {
|
||||
for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) {
|
||||
try {
|
||||
const lastUuid = lastUuidMap.get(sessionId)
|
||||
const requestHeaders = { ...headers }
|
||||
if (lastUuid) {
|
||||
requestHeaders['Last-Uuid'] = lastUuid
|
||||
}
|
||||
|
||||
const response = await axios.put(url, entry, {
|
||||
headers: requestHeaders,
|
||||
validateStatus: status => status < 500,
|
||||
})
|
||||
|
||||
if (response.status === 200 || response.status === 201) {
|
||||
lastUuidMap.set(sessionId, entry.uuid)
|
||||
logForDebugging(
|
||||
`Successfully persisted session log entry for session ${sessionId}`,
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
if (response.status === 409) {
|
||||
// Check if our entry was actually stored (server returned 409 but entry exists)
|
||||
// This handles the scenario where entry was stored but client received an error
|
||||
// response, causing lastUuidMap to be stale
|
||||
const serverLastUuid = response.headers['x-last-uuid']
|
||||
if (serverLastUuid === entry.uuid) {
|
||||
// Our entry IS the last entry on server - it was stored successfully previously
|
||||
lastUuidMap.set(sessionId, entry.uuid)
|
||||
logForDebugging(
|
||||
`Session entry ${entry.uuid} already present on server, recovering from stale state`,
|
||||
)
|
||||
logForDiagnosticsNoPII('info', 'session_persist_recovered_from_409')
|
||||
return true
|
||||
}
|
||||
|
||||
// Another writer (e.g. in-flight request from a killed process)
|
||||
// advanced the server's chain. Try to adopt the server's last UUID
|
||||
// from the response header, or re-fetch the session to discover it.
|
||||
if (serverLastUuid) {
|
||||
lastUuidMap.set(sessionId, serverLastUuid as UUID)
|
||||
logForDebugging(
|
||||
`Session 409: adopting server lastUuid=${serverLastUuid} from header, retrying entry ${entry.uuid}`,
|
||||
)
|
||||
} else {
|
||||
// Server didn't return x-last-uuid (e.g. v1 endpoint). Re-fetch
|
||||
// the session to discover the current head of the append chain.
|
||||
const logs = await fetchSessionLogsFromUrl(sessionId, url, headers)
|
||||
const adoptedUuid = findLastUuid(logs)
|
||||
if (adoptedUuid) {
|
||||
lastUuidMap.set(sessionId, adoptedUuid)
|
||||
logForDebugging(
|
||||
`Session 409: re-fetched ${logs!.length} entries, adopting lastUuid=${adoptedUuid}, retrying entry ${entry.uuid}`,
|
||||
)
|
||||
} else {
|
||||
// Can't determine server state — give up
|
||||
const errorData = response.data as SessionIngressError
|
||||
const errorMessage =
|
||||
errorData.error?.message || 'Concurrent modification detected'
|
||||
logError(
|
||||
new Error(
|
||||
`Session persistence conflict: UUID mismatch for session ${sessionId}, entry ${entry.uuid}. ${errorMessage}`,
|
||||
),
|
||||
)
|
||||
logForDiagnosticsNoPII(
|
||||
'error',
|
||||
'session_persist_fail_concurrent_modification',
|
||||
)
|
||||
return false
|
||||
}
|
||||
}
|
||||
logForDiagnosticsNoPII('info', 'session_persist_409_adopt_server_uuid')
|
||||
continue // retry with updated lastUuid
|
||||
}
|
||||
|
||||
if (response.status === 401) {
|
||||
logForDebugging('Session token expired or invalid')
|
||||
logForDiagnosticsNoPII('error', 'session_persist_fail_bad_token')
|
||||
return false // Non-retryable
|
||||
}
|
||||
|
||||
// Other 4xx (429, etc.) - retryable
|
||||
logForDebugging(
|
||||
`Failed to persist session log: ${response.status} ${response.statusText}`,
|
||||
)
|
||||
logForDiagnosticsNoPII('error', 'session_persist_fail_status', {
|
||||
status: response.status,
|
||||
attempt,
|
||||
})
|
||||
} catch (error) {
|
||||
// Network errors, 5xx - retryable
|
||||
const axiosError = error as AxiosError<SessionIngressError>
|
||||
logError(new Error(`Error persisting session log: ${axiosError.message}`))
|
||||
logForDiagnosticsNoPII('error', 'session_persist_fail_status', {
|
||||
status: axiosError.status,
|
||||
attempt,
|
||||
})
|
||||
}
|
||||
|
||||
if (attempt === MAX_RETRIES) {
|
||||
logForDebugging(`Remote persistence failed after ${MAX_RETRIES} attempts`)
|
||||
logForDiagnosticsNoPII(
|
||||
'error',
|
||||
'session_persist_error_retries_exhausted',
|
||||
{ attempt },
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
const delayMs = Math.min(BASE_DELAY_MS * Math.pow(2, attempt - 1), 8000)
|
||||
logForDebugging(
|
||||
`Remote persistence attempt ${attempt}/${MAX_RETRIES} failed, retrying in ${delayMs}ms…`,
|
||||
)
|
||||
await sleep(delayMs)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Append a log entry to the session using JWT token
|
||||
* Uses optimistic concurrency control with Last-Uuid header
|
||||
* Ensures sequential execution per session to prevent race conditions
|
||||
*/
|
||||
export async function appendSessionLog(
|
||||
sessionId: string,
|
||||
entry: TranscriptMessage,
|
||||
url: string,
|
||||
): Promise<boolean> {
|
||||
const sessionToken = getSessionIngressAuthToken()
|
||||
if (!sessionToken) {
|
||||
logForDebugging('No session token available for session persistence')
|
||||
logForDiagnosticsNoPII('error', 'session_persist_fail_jwt_no_token')
|
||||
return false
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
Authorization: `Bearer ${sessionToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
const sequentialAppend = getOrCreateSequentialAppend(sessionId)
|
||||
return sequentialAppend(entry, url, headers)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all session logs for hydration
|
||||
*/
|
||||
export async function getSessionLogs(
|
||||
sessionId: string,
|
||||
url: string,
|
||||
): Promise<Entry[] | null> {
|
||||
const sessionToken = getSessionIngressAuthToken()
|
||||
if (!sessionToken) {
|
||||
logForDebugging('No session token available for fetching session logs')
|
||||
logForDiagnosticsNoPII('error', 'session_get_fail_no_token')
|
||||
return null
|
||||
}
|
||||
|
||||
const headers = { Authorization: `Bearer ${sessionToken}` }
|
||||
const logs = await fetchSessionLogsFromUrl(sessionId, url, headers)
|
||||
|
||||
if (logs && logs.length > 0) {
|
||||
// Update our lastUuid to the last entry's UUID
|
||||
const lastEntry = logs.at(-1)
|
||||
if (lastEntry && 'uuid' in lastEntry && lastEntry.uuid) {
|
||||
lastUuidMap.set(sessionId, lastEntry.uuid)
|
||||
}
|
||||
}
|
||||
|
||||
return logs
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all session logs for hydration via OAuth
|
||||
* Used for teleporting sessions from the Sessions API
|
||||
*/
|
||||
export async function getSessionLogsViaOAuth(
|
||||
sessionId: string,
|
||||
accessToken: string,
|
||||
orgUUID: string,
|
||||
): Promise<Entry[] | null> {
|
||||
const url = `${getOauthConfig().BASE_API_URL}/v1/session_ingress/session/${sessionId}`
|
||||
logForDebugging(`[session-ingress] Fetching session logs from: ${url}`)
|
||||
const headers = {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
}
|
||||
const result = await fetchSessionLogsFromUrl(sessionId, url, headers)
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Response shape from GET /v1/code/sessions/{id}/teleport-events.
|
||||
* WorkerEvent.payload IS the Entry (TranscriptMessage struct) — the CLI
|
||||
* writes it via AddWorkerEvent, the server stores it opaque, we read it
|
||||
* back here.
|
||||
*/
|
||||
type TeleportEventsResponse = {
|
||||
data: Array<{
|
||||
event_id: string
|
||||
event_type: string
|
||||
is_compaction: boolean
|
||||
payload: Entry | null
|
||||
created_at: string
|
||||
}>
|
||||
// Unset when there are no more pages — this IS the end-of-stream
|
||||
// signal (no separate has_more field).
|
||||
next_cursor?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Get worker events (transcript) via the CCR v2 Sessions API. Replaces
|
||||
* getSessionLogsViaOAuth once session-ingress is retired.
|
||||
*
|
||||
* The server dispatches per-session: Spanner for v2-native sessions,
|
||||
* threadstore for pre-backfill session_* IDs. The cursor is opaque to us —
|
||||
* echo it back until next_cursor is unset.
|
||||
*
|
||||
* Paginated (500/page default, server max 1000). session-ingress's one-shot
|
||||
* 50k is gone; we loop.
|
||||
*/
|
||||
export async function getTeleportEvents(
|
||||
sessionId: string,
|
||||
accessToken: string,
|
||||
orgUUID: string,
|
||||
): Promise<Entry[] | null> {
|
||||
const baseUrl = `${getOauthConfig().BASE_API_URL}/v1/code/sessions/${sessionId}/teleport-events`
|
||||
const headers = {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
}
|
||||
|
||||
logForDebugging(`[teleport] Fetching events from: ${baseUrl}`)
|
||||
|
||||
const all: Entry[] = []
|
||||
let cursor: string | undefined
|
||||
let pages = 0
|
||||
|
||||
// Infinite-loop guard: 1000/page × 100 pages = 100k events. Larger than
|
||||
// session-ingress's 50k one-shot. If we hit this, something's wrong
|
||||
// (server not advancing cursor) — bail rather than hang.
|
||||
const maxPages = 100
|
||||
|
||||
while (pages < maxPages) {
|
||||
const params: Record<string, string | number> = { limit: 1000 }
|
||||
if (cursor !== undefined) {
|
||||
params.cursor = cursor
|
||||
}
|
||||
|
||||
let response
|
||||
try {
|
||||
response = await axios.get<TeleportEventsResponse>(baseUrl, {
|
||||
headers,
|
||||
params,
|
||||
timeout: 20000,
|
||||
validateStatus: status => status < 500,
|
||||
})
|
||||
} catch (e) {
|
||||
const err = e as AxiosError
|
||||
logError(new Error(`Teleport events fetch failed: ${err.message}`))
|
||||
logForDiagnosticsNoPII('error', 'teleport_events_fetch_fail')
|
||||
return null
|
||||
}
|
||||
|
||||
if (response.status === 404) {
|
||||
// 404 on page 0 is ambiguous during the migration window:
|
||||
// (a) Session genuinely not found (not in Spanner AND not in
|
||||
// threadstore) — nothing to fetch.
|
||||
// (b) Route-level 404: endpoint not deployed yet, or session is
|
||||
// a threadstore session not yet backfilled into Spanner.
|
||||
// We can't tell them apart from the response alone. Returning null
|
||||
// lets the caller fall back to session-ingress, which will correctly
|
||||
// return empty for case (a) and data for case (b). Once the backfill
|
||||
// is complete and session-ingress is gone, the fallback also returns
|
||||
// null → same "Failed to fetch session logs" error as today.
|
||||
//
|
||||
// 404 mid-pagination (pages > 0) means session was deleted between
|
||||
// pages — return what we have.
|
||||
logForDebugging(
|
||||
`[teleport] Session ${sessionId} not found (page ${pages})`,
|
||||
)
|
||||
logForDiagnosticsNoPII('warn', 'teleport_events_not_found')
|
||||
return pages === 0 ? null : all
|
||||
}
|
||||
|
||||
if (response.status === 401) {
|
||||
logForDiagnosticsNoPII('error', 'teleport_events_bad_token')
|
||||
throw new Error(
|
||||
'Your session has expired. Please run /login to sign in again.',
|
||||
)
|
||||
}
|
||||
|
||||
if (response.status !== 200) {
|
||||
logError(
|
||||
new Error(
|
||||
`Teleport events returned ${response.status}: ${jsonStringify(response.data)}`,
|
||||
),
|
||||
)
|
||||
logForDiagnosticsNoPII('error', 'teleport_events_bad_status')
|
||||
return null
|
||||
}
|
||||
|
||||
const { data, next_cursor } = response.data
|
||||
if (!Array.isArray(data)) {
|
||||
logError(
|
||||
new Error(
|
||||
`Teleport events invalid response shape: ${jsonStringify(response.data)}`,
|
||||
),
|
||||
)
|
||||
logForDiagnosticsNoPII('error', 'teleport_events_invalid_shape')
|
||||
return null
|
||||
}
|
||||
|
||||
// payload IS the Entry. null payload happens for threadstore non-generic
|
||||
// events (server skips them) or encryption failures — skip here too.
|
||||
for (const ev of data) {
|
||||
if (ev.payload !== null) {
|
||||
all.push(ev.payload)
|
||||
}
|
||||
}
|
||||
|
||||
pages++
|
||||
// == null covers both `null` and `undefined` — the proto omits the
|
||||
// field at end-of-stream, but some serializers emit `null`. Strict
|
||||
// `=== undefined` would loop forever on `null` (cursor=null in query
|
||||
// params stringifies to "null", which the server rejects or echoes).
|
||||
if (next_cursor == null) {
|
||||
break
|
||||
}
|
||||
cursor = next_cursor
|
||||
}
|
||||
|
||||
if (pages >= maxPages) {
|
||||
// Don't fail — return what we have. Better to teleport with a
|
||||
// truncated transcript than not at all.
|
||||
logError(
|
||||
new Error(`Teleport events hit page cap (${maxPages}) for ${sessionId}`),
|
||||
)
|
||||
logForDiagnosticsNoPII('warn', 'teleport_events_page_cap')
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[teleport] Fetched ${all.length} events over ${pages} page(s) for ${sessionId}`,
|
||||
)
|
||||
return all
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared implementation for fetching session logs from a URL
|
||||
*/
|
||||
async function fetchSessionLogsFromUrl(
|
||||
sessionId: string,
|
||||
url: string,
|
||||
headers: Record<string, string>,
|
||||
): Promise<Entry[] | null> {
|
||||
try {
|
||||
const response = await axios.get(url, {
|
||||
headers,
|
||||
timeout: 20000,
|
||||
validateStatus: status => status < 500,
|
||||
params: isEnvTruthy(process.env.CLAUDE_AFTER_LAST_COMPACT)
|
||||
? { after_last_compact: true }
|
||||
: undefined,
|
||||
})
|
||||
|
||||
if (response.status === 200) {
|
||||
const data = response.data
|
||||
|
||||
// Validate the response structure
|
||||
if (!data || typeof data !== 'object' || !Array.isArray(data.loglines)) {
|
||||
logError(
|
||||
new Error(
|
||||
`Invalid session logs response format: ${jsonStringify(data)}`,
|
||||
),
|
||||
)
|
||||
logForDiagnosticsNoPII('error', 'session_get_fail_invalid_response')
|
||||
return null
|
||||
}
|
||||
|
||||
const logs = data.loglines as Entry[]
|
||||
logForDebugging(
|
||||
`Fetched ${logs.length} session logs for session ${sessionId}`,
|
||||
)
|
||||
return logs
|
||||
}
|
||||
|
||||
if (response.status === 404) {
|
||||
logForDebugging(`No existing logs for session ${sessionId}`)
|
||||
logForDiagnosticsNoPII('warn', 'session_get_no_logs_for_session')
|
||||
return []
|
||||
}
|
||||
|
||||
if (response.status === 401) {
|
||||
logForDebugging('Auth token expired or invalid')
|
||||
logForDiagnosticsNoPII('error', 'session_get_fail_bad_token')
|
||||
throw new Error(
|
||||
'Your session has expired. Please run /login to sign in again.',
|
||||
)
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`Failed to fetch session logs: ${response.status} ${response.statusText}`,
|
||||
)
|
||||
logForDiagnosticsNoPII('error', 'session_get_fail_status', {
|
||||
status: response.status,
|
||||
})
|
||||
return null
|
||||
} catch (error) {
|
||||
const axiosError = error as AxiosError<SessionIngressError>
|
||||
logError(new Error(`Error fetching session logs: ${axiosError.message}`))
|
||||
logForDiagnosticsNoPII('error', 'session_get_fail_status', {
|
||||
status: axiosError.status,
|
||||
})
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Walk backward through entries to find the last one with a uuid.
|
||||
* Some entry types (SummaryMessage, TagMessage) don't have one.
|
||||
*/
|
||||
function findLastUuid(logs: Entry[] | null): UUID | undefined {
|
||||
if (!logs) {
|
||||
return undefined
|
||||
}
|
||||
const entry = logs.findLast(e => 'uuid' in e && e.uuid)
|
||||
return entry && 'uuid' in entry ? (entry.uuid as UUID) : undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached state for a session
|
||||
*/
|
||||
export function clearSession(sessionId: string): void {
|
||||
lastUuidMap.delete(sessionId)
|
||||
sequentialAppendBySession.delete(sessionId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all cached session state (all sessions).
|
||||
* Use this on /clear to free sub-agent session entries.
|
||||
*/
|
||||
export function clearAllSessions(): void {
|
||||
lastUuidMap.clear()
|
||||
sequentialAppendBySession.clear()
|
||||
}
|
||||
38
src/services/api/ultrareviewQuota.ts
Normal file
38
src/services/api/ultrareviewQuota.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import axios from 'axios'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import { isClaudeAISubscriber } from '../../utils/auth.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { getOAuthHeaders, prepareApiRequest } from '../../utils/teleport/api.js'
|
||||
|
||||
export type UltrareviewQuotaResponse = {
|
||||
reviews_used: number
|
||||
reviews_limit: number
|
||||
reviews_remaining: number
|
||||
is_overage: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Peek the ultrareview quota for display and nudge decisions. Consume
|
||||
* happens server-side at session creation. Null when not a subscriber or
|
||||
* the endpoint errors.
|
||||
*/
|
||||
export async function fetchUltrareviewQuota(): Promise<UltrareviewQuotaResponse | null> {
|
||||
if (!isClaudeAISubscriber()) return null
|
||||
try {
|
||||
const { accessToken, orgUUID } = await prepareApiRequest()
|
||||
const response = await axios.get<UltrareviewQuotaResponse>(
|
||||
`${getOauthConfig().BASE_API_URL}/v1/ultrareview/quota`,
|
||||
{
|
||||
headers: {
|
||||
...getOAuthHeaders(accessToken),
|
||||
'x-organization-uuid': orgUUID,
|
||||
},
|
||||
timeout: 5000,
|
||||
},
|
||||
)
|
||||
return response.data
|
||||
} catch (error) {
|
||||
logForDebugging(`fetchUltrareviewQuota failed: ${error}`)
|
||||
return null
|
||||
}
|
||||
}
|
||||
63
src/services/api/usage.ts
Normal file
63
src/services/api/usage.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
import axios from 'axios'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import {
|
||||
getClaudeAIOAuthTokens,
|
||||
hasProfileScope,
|
||||
isClaudeAISubscriber,
|
||||
} from '../../utils/auth.js'
|
||||
import { getAuthHeaders } from '../../utils/http.js'
|
||||
import { getClaudeCodeUserAgent } from '../../utils/userAgent.js'
|
||||
import { isOAuthTokenExpired } from '../oauth/client.js'
|
||||
|
||||
export type RateLimit = {
|
||||
utilization: number | null // a percentage from 0 to 100
|
||||
resets_at: string | null // ISO 8601 timestamp
|
||||
}
|
||||
|
||||
export type ExtraUsage = {
|
||||
is_enabled: boolean
|
||||
monthly_limit: number | null
|
||||
used_credits: number | null
|
||||
utilization: number | null
|
||||
}
|
||||
|
||||
export type Utilization = {
|
||||
five_hour?: RateLimit | null
|
||||
seven_day?: RateLimit | null
|
||||
seven_day_oauth_apps?: RateLimit | null
|
||||
seven_day_opus?: RateLimit | null
|
||||
seven_day_sonnet?: RateLimit | null
|
||||
extra_usage?: ExtraUsage | null
|
||||
}
|
||||
|
||||
export async function fetchUtilization(): Promise<Utilization | null> {
|
||||
if (!isClaudeAISubscriber() || !hasProfileScope()) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// Skip API call if OAuth token is expired to avoid 401 errors
|
||||
const tokens = getClaudeAIOAuthTokens()
|
||||
if (tokens && isOAuthTokenExpired(tokens.expiresAt)) {
|
||||
return null
|
||||
}
|
||||
|
||||
const authResult = getAuthHeaders()
|
||||
if (authResult.error) {
|
||||
throw new Error(`Auth error: ${authResult.error}`)
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': getClaudeCodeUserAgent(),
|
||||
...authResult.headers,
|
||||
}
|
||||
|
||||
const url = `${getOauthConfig().BASE_API_URL}/api/oauth/usage`
|
||||
|
||||
const response = await axios.get<Utilization>(url, {
|
||||
headers,
|
||||
timeout: 5000, // 5 second timeout
|
||||
})
|
||||
|
||||
return response.data
|
||||
}
|
||||
822
src/services/api/withRetry.ts
Normal file
822
src/services/api/withRetry.ts
Normal file
@@ -0,0 +1,822 @@
|
||||
import { feature } from 'bun:bundle'
|
||||
import type Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
APIUserAbortError,
|
||||
} from '@anthropic-ai/sdk'
|
||||
import type { QuerySource } from 'src/constants/querySource.js'
|
||||
import type { SystemAPIErrorMessage } from 'src/types/message.js'
|
||||
import { isAwsCredentialsProviderError } from 'src/utils/aws.js'
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
import { logError } from 'src/utils/log.js'
|
||||
import { createSystemAPIErrorMessage } from 'src/utils/messages.js'
|
||||
import { getAPIProviderForStatsig } from 'src/utils/model/providers.js'
|
||||
import {
|
||||
clearApiKeyHelperCache,
|
||||
clearAwsCredentialsCache,
|
||||
clearGcpCredentialsCache,
|
||||
getClaudeAIOAuthTokens,
|
||||
handleOAuth401Error,
|
||||
isClaudeAISubscriber,
|
||||
isEnterpriseSubscriber,
|
||||
} from '../../utils/auth.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import {
|
||||
type CooldownReason,
|
||||
handleFastModeOverageRejection,
|
||||
handleFastModeRejectedByAPI,
|
||||
isFastModeCooldown,
|
||||
isFastModeEnabled,
|
||||
triggerFastModeCooldown,
|
||||
} from '../../utils/fastMode.js'
|
||||
import { isNonCustomOpusModel } from '../../utils/model/model.js'
|
||||
import { disableKeepAlive } from '../../utils/proxy.js'
|
||||
import { sleep } from '../../utils/sleep.js'
|
||||
import type { ThinkingConfig } from '../../utils/thinking.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
import {
|
||||
checkMockRateLimitError,
|
||||
isMockRateLimitError,
|
||||
} from '../rateLimitMocking.js'
|
||||
import { REPEATED_529_ERROR_MESSAGE } from './errors.js'
|
||||
import { extractConnectionErrorDetails } from './errorUtils.js'
|
||||
|
||||
const abortError = () => new APIUserAbortError()
|
||||
|
||||
const DEFAULT_MAX_RETRIES = 10
|
||||
const FLOOR_OUTPUT_TOKENS = 3000
|
||||
const MAX_529_RETRIES = 3
|
||||
export const BASE_DELAY_MS = 500
|
||||
|
||||
// Foreground query sources where the user IS blocking on the result — these
|
||||
// retry on 529. Everything else (summaries, titles, suggestions, classifiers)
|
||||
// bails immediately: during a capacity cascade each retry is 3-10× gateway
|
||||
// amplification, and the user never sees those fail anyway. New sources
|
||||
// default to no-retry — add here only if the user is waiting on the result.
|
||||
const FOREGROUND_529_RETRY_SOURCES = new Set<QuerySource>([
|
||||
'repl_main_thread',
|
||||
'repl_main_thread:outputStyle:custom',
|
||||
'repl_main_thread:outputStyle:Explanatory',
|
||||
'repl_main_thread:outputStyle:Learning',
|
||||
'sdk',
|
||||
'agent:custom',
|
||||
'agent:default',
|
||||
'agent:builtin',
|
||||
'compact',
|
||||
'hook_agent',
|
||||
'hook_prompt',
|
||||
'verification_agent',
|
||||
'side_question',
|
||||
// Security classifiers — must complete for auto-mode correctness.
|
||||
// yoloClassifier.ts uses 'auto_mode' (not 'yolo_classifier' — that's
|
||||
// type-only). bash_classifier is ant-only; feature-gate so the string
|
||||
// tree-shakes out of external builds (excluded-strings.txt).
|
||||
'auto_mode',
|
||||
...(feature('BASH_CLASSIFIER') ? (['bash_classifier'] as const) : []),
|
||||
])
|
||||
|
||||
function shouldRetry529(querySource: QuerySource | undefined): boolean {
|
||||
// undefined → retry (conservative for untagged call paths)
|
||||
return (
|
||||
querySource === undefined || FOREGROUND_529_RETRY_SOURCES.has(querySource)
|
||||
)
|
||||
}
|
||||
|
||||
// CLAUDE_CODE_UNATTENDED_RETRY: for unattended sessions (ant-only). Retries 429/529
|
||||
// indefinitely with higher backoff and periodic keep-alive yields so the host
|
||||
// environment does not mark the session idle mid-wait.
|
||||
// TODO(ANT-344): the keep-alive via SystemAPIErrorMessage yields is a stopgap
|
||||
// until there's a dedicated keep-alive channel.
|
||||
const PERSISTENT_MAX_BACKOFF_MS = 5 * 60 * 1000
|
||||
const PERSISTENT_RESET_CAP_MS = 6 * 60 * 60 * 1000
|
||||
const HEARTBEAT_INTERVAL_MS = 30_000
|
||||
|
||||
function isPersistentRetryEnabled(): boolean {
|
||||
return feature('UNATTENDED_RETRY')
|
||||
? isEnvTruthy(process.env.CLAUDE_CODE_UNATTENDED_RETRY)
|
||||
: false
|
||||
}
|
||||
|
||||
function isTransientCapacityError(error: unknown): boolean {
|
||||
return (
|
||||
is529Error(error) || (error instanceof APIError && error.status === 429)
|
||||
)
|
||||
}
|
||||
|
||||
function isStaleConnectionError(error: unknown): boolean {
|
||||
if (!(error instanceof APIConnectionError)) {
|
||||
return false
|
||||
}
|
||||
const details = extractConnectionErrorDetails(error)
|
||||
return details?.code === 'ECONNRESET' || details?.code === 'EPIPE'
|
||||
}
|
||||
|
||||
export interface RetryContext {
|
||||
maxTokensOverride?: number
|
||||
model: string
|
||||
thinkingConfig: ThinkingConfig
|
||||
fastMode?: boolean
|
||||
}
|
||||
|
||||
interface RetryOptions {
|
||||
maxRetries?: number
|
||||
model: string
|
||||
fallbackModel?: string
|
||||
thinkingConfig: ThinkingConfig
|
||||
fastMode?: boolean
|
||||
signal?: AbortSignal
|
||||
querySource?: QuerySource
|
||||
/**
|
||||
* Pre-seed the consecutive 529 counter. Used when this retry loop is a
|
||||
* non-streaming fallback after a streaming 529 — the streaming 529 should
|
||||
* count toward MAX_529_RETRIES so total 529s-before-fallback is consistent
|
||||
* regardless of which request mode hit the overload.
|
||||
*/
|
||||
initialConsecutive529Errors?: number
|
||||
}
|
||||
|
||||
export class CannotRetryError extends Error {
|
||||
constructor(
|
||||
public readonly originalError: unknown,
|
||||
public readonly retryContext: RetryContext,
|
||||
) {
|
||||
const message = errorMessage(originalError)
|
||||
super(message)
|
||||
this.name = 'RetryError'
|
||||
|
||||
// Preserve the original stack trace if available
|
||||
if (originalError instanceof Error && originalError.stack) {
|
||||
this.stack = originalError.stack
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class FallbackTriggeredError extends Error {
|
||||
constructor(
|
||||
public readonly originalModel: string,
|
||||
public readonly fallbackModel: string,
|
||||
) {
|
||||
super(`Model fallback triggered: ${originalModel} -> ${fallbackModel}`)
|
||||
this.name = 'FallbackTriggeredError'
|
||||
}
|
||||
}
|
||||
|
||||
export async function* withRetry<T>(
|
||||
getClient: () => Promise<Anthropic>,
|
||||
operation: (
|
||||
client: Anthropic,
|
||||
attempt: number,
|
||||
context: RetryContext,
|
||||
) => Promise<T>,
|
||||
options: RetryOptions,
|
||||
): AsyncGenerator<SystemAPIErrorMessage, T> {
|
||||
const maxRetries = getMaxRetries(options)
|
||||
const retryContext: RetryContext = {
|
||||
model: options.model,
|
||||
thinkingConfig: options.thinkingConfig,
|
||||
...(isFastModeEnabled() && { fastMode: options.fastMode }),
|
||||
}
|
||||
let client: Anthropic | null = null
|
||||
let consecutive529Errors = options.initialConsecutive529Errors ?? 0
|
||||
let lastError: unknown
|
||||
let persistentAttempt = 0
|
||||
for (let attempt = 1; attempt <= maxRetries + 1; attempt++) {
|
||||
if (options.signal?.aborted) {
|
||||
throw new APIUserAbortError()
|
||||
}
|
||||
|
||||
// Capture whether fast mode is active before this attempt
|
||||
// (fallback may change the state mid-loop)
|
||||
const wasFastModeActive = isFastModeEnabled()
|
||||
? retryContext.fastMode && !isFastModeCooldown()
|
||||
: false
|
||||
|
||||
try {
|
||||
// Check for mock rate limits (used by /mock-limits command for Ant employees)
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
const mockError = checkMockRateLimitError(
|
||||
retryContext.model,
|
||||
wasFastModeActive,
|
||||
)
|
||||
if (mockError) {
|
||||
throw mockError
|
||||
}
|
||||
}
|
||||
|
||||
// Get a fresh client instance on first attempt or after authentication errors
|
||||
// - 401 for first-party API authentication failures
|
||||
// - 403 "OAuth token has been revoked" (another process refreshed the token)
|
||||
// - Bedrock-specific auth errors (403 or CredentialsProviderError)
|
||||
// - Vertex-specific auth errors (credential refresh failures, 401)
|
||||
// - ECONNRESET/EPIPE: stale keep-alive socket; disable pooling and reconnect
|
||||
const isStaleConnection = isStaleConnectionError(lastError)
|
||||
if (
|
||||
isStaleConnection &&
|
||||
getFeatureValue_CACHED_MAY_BE_STALE(
|
||||
'tengu_disable_keepalive_on_econnreset',
|
||||
false,
|
||||
)
|
||||
) {
|
||||
logForDebugging(
|
||||
'Stale connection (ECONNRESET/EPIPE) — disabling keep-alive for retry',
|
||||
)
|
||||
disableKeepAlive()
|
||||
}
|
||||
|
||||
if (
|
||||
client === null ||
|
||||
(lastError instanceof APIError && lastError.status === 401) ||
|
||||
isOAuthTokenRevokedError(lastError) ||
|
||||
isBedrockAuthError(lastError) ||
|
||||
isVertexAuthError(lastError) ||
|
||||
isStaleConnection
|
||||
) {
|
||||
// On 401 "token expired" or 403 "token revoked", force a token refresh
|
||||
if (
|
||||
(lastError instanceof APIError && lastError.status === 401) ||
|
||||
isOAuthTokenRevokedError(lastError)
|
||||
) {
|
||||
const failedAccessToken = getClaudeAIOAuthTokens()?.accessToken
|
||||
if (failedAccessToken) {
|
||||
await handleOAuth401Error(failedAccessToken)
|
||||
}
|
||||
}
|
||||
client = await getClient()
|
||||
}
|
||||
|
||||
return await operation(client, attempt, retryContext)
|
||||
} catch (error) {
|
||||
lastError = error
|
||||
logForDebugging(
|
||||
`API error (attempt ${attempt}/${maxRetries + 1}): ${error instanceof APIError ? `${error.status} ${error.message}` : errorMessage(error)}`,
|
||||
{ level: 'error' },
|
||||
)
|
||||
|
||||
// Fast mode fallback: on 429/529, either wait and retry (short delays)
|
||||
// or fall back to standard speed (long delays) to avoid cache thrashing.
|
||||
// Skip in persistent mode: the short-retry path below loops with fast
|
||||
// mode still active, so its `continue` never reaches the attempt clamp
|
||||
// and the for-loop terminates. Persistent sessions want the chunked
|
||||
// keep-alive path instead of fast-mode cache-preservation anyway.
|
||||
if (
|
||||
wasFastModeActive &&
|
||||
!isPersistentRetryEnabled() &&
|
||||
error instanceof APIError &&
|
||||
(error.status === 429 || is529Error(error))
|
||||
) {
|
||||
// If the 429 is specifically because extra usage (overage) is not
|
||||
// available, permanently disable fast mode with a specific message.
|
||||
const overageReason = error.headers?.get(
|
||||
'anthropic-ratelimit-unified-overage-disabled-reason',
|
||||
)
|
||||
if (overageReason !== null && overageReason !== undefined) {
|
||||
handleFastModeOverageRejection(overageReason)
|
||||
retryContext.fastMode = false
|
||||
continue
|
||||
}
|
||||
|
||||
const retryAfterMs = getRetryAfterMs(error)
|
||||
if (retryAfterMs !== null && retryAfterMs < SHORT_RETRY_THRESHOLD_MS) {
|
||||
// Short retry-after: wait and retry with fast mode still active
|
||||
// to preserve prompt cache (same model name on retry).
|
||||
await sleep(retryAfterMs, options.signal, { abortError })
|
||||
continue
|
||||
}
|
||||
// Long or unknown retry-after: enter cooldown (switches to standard
|
||||
// speed model), with a minimum floor to avoid flip-flopping.
|
||||
const cooldownMs = Math.max(
|
||||
retryAfterMs ?? DEFAULT_FAST_MODE_FALLBACK_HOLD_MS,
|
||||
MIN_COOLDOWN_MS,
|
||||
)
|
||||
const cooldownReason: CooldownReason = is529Error(error)
|
||||
? 'overloaded'
|
||||
: 'rate_limit'
|
||||
triggerFastModeCooldown(Date.now() + cooldownMs, cooldownReason)
|
||||
if (isFastModeEnabled()) {
|
||||
retryContext.fastMode = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Fast mode fallback: if the API rejects the fast mode parameter
|
||||
// (e.g., org doesn't have fast mode enabled), permanently disable fast
|
||||
// mode and retry at standard speed.
|
||||
if (wasFastModeActive && isFastModeNotEnabledError(error)) {
|
||||
handleFastModeRejectedByAPI()
|
||||
retryContext.fastMode = false
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-foreground sources bail immediately on 529 — no retry amplification
|
||||
// during capacity cascades. User never sees these fail.
|
||||
if (is529Error(error) && !shouldRetry529(options.querySource)) {
|
||||
logEvent('tengu_api_529_background_dropped', {
|
||||
query_source:
|
||||
options.querySource as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
throw new CannotRetryError(error, retryContext)
|
||||
}
|
||||
|
||||
// Track consecutive 529 errors
|
||||
if (
|
||||
is529Error(error) &&
|
||||
// If FALLBACK_FOR_ALL_PRIMARY_MODELS is not set, fall through only if the primary model is a non-custom Opus model.
|
||||
// TODO: Revisit if the isNonCustomOpusModel check should still exist, or if isNonCustomOpusModel is a stale artifact of when Claude Code was hardcoded on Opus.
|
||||
(process.env.FALLBACK_FOR_ALL_PRIMARY_MODELS ||
|
||||
(!isClaudeAISubscriber() && isNonCustomOpusModel(options.model)))
|
||||
) {
|
||||
consecutive529Errors++
|
||||
if (consecutive529Errors >= MAX_529_RETRIES) {
|
||||
// Check if fallback model is specified
|
||||
if (options.fallbackModel) {
|
||||
logEvent('tengu_api_opus_fallback_triggered', {
|
||||
original_model:
|
||||
options.model as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
fallback_model:
|
||||
options.fallbackModel as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
provider: getAPIProviderForStatsig(),
|
||||
})
|
||||
|
||||
// Throw special error to indicate fallback was triggered
|
||||
throw new FallbackTriggeredError(
|
||||
options.model,
|
||||
options.fallbackModel,
|
||||
)
|
||||
}
|
||||
|
||||
if (
|
||||
process.env.USER_TYPE === 'external' &&
|
||||
!process.env.IS_SANDBOX &&
|
||||
!isPersistentRetryEnabled()
|
||||
) {
|
||||
logEvent('tengu_api_custom_529_overloaded_error', {})
|
||||
throw new CannotRetryError(
|
||||
new Error(REPEATED_529_ERROR_MESSAGE),
|
||||
retryContext,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only retry if the error indicates we should
|
||||
const persistent =
|
||||
isPersistentRetryEnabled() && isTransientCapacityError(error)
|
||||
if (attempt > maxRetries && !persistent) {
|
||||
throw new CannotRetryError(error, retryContext)
|
||||
}
|
||||
|
||||
// AWS/GCP errors aren't always APIError, but can be retried
|
||||
const handledCloudAuthError =
|
||||
handleAwsCredentialError(error) || handleGcpCredentialError(error)
|
||||
if (
|
||||
!handledCloudAuthError &&
|
||||
(!(error instanceof APIError) || !shouldRetry(error))
|
||||
) {
|
||||
throw new CannotRetryError(error, retryContext)
|
||||
}
|
||||
|
||||
// Handle max tokens context overflow errors by adjusting max_tokens for the next attempt
|
||||
// NOTE: With extended-context-window beta, this 400 error should not occur.
|
||||
// The API now returns 'model_context_window_exceeded' stop_reason instead.
|
||||
// Keeping for backward compatibility.
|
||||
if (error instanceof APIError) {
|
||||
const overflowData = parseMaxTokensContextOverflowError(error)
|
||||
if (overflowData) {
|
||||
const { inputTokens, contextLimit } = overflowData
|
||||
|
||||
const safetyBuffer = 1000
|
||||
const availableContext = Math.max(
|
||||
0,
|
||||
contextLimit - inputTokens - safetyBuffer,
|
||||
)
|
||||
if (availableContext < FLOOR_OUTPUT_TOKENS) {
|
||||
logError(
|
||||
new Error(
|
||||
`availableContext ${availableContext} is less than FLOOR_OUTPUT_TOKENS ${FLOOR_OUTPUT_TOKENS}`,
|
||||
),
|
||||
)
|
||||
throw error
|
||||
}
|
||||
// Ensure we have enough tokens for thinking + at least 1 output token
|
||||
const minRequired =
|
||||
(retryContext.thinkingConfig.type === 'enabled'
|
||||
? retryContext.thinkingConfig.budgetTokens
|
||||
: 0) + 1
|
||||
const adjustedMaxTokens = Math.max(
|
||||
FLOOR_OUTPUT_TOKENS,
|
||||
availableContext,
|
||||
minRequired,
|
||||
)
|
||||
retryContext.maxTokensOverride = adjustedMaxTokens
|
||||
|
||||
logEvent('tengu_max_tokens_context_overflow_adjustment', {
|
||||
inputTokens,
|
||||
contextLimit,
|
||||
adjustedMaxTokens,
|
||||
attempt,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// For other errors, proceed with normal retry logic
|
||||
// Get retry-after header if available
|
||||
const retryAfter = getRetryAfter(error)
|
||||
let delayMs: number
|
||||
if (persistent && error instanceof APIError && error.status === 429) {
|
||||
persistentAttempt++
|
||||
// Window-based limits (e.g. 5hr Max/Pro) include a reset timestamp.
|
||||
// Wait until reset rather than polling every 5 min uselessly.
|
||||
const resetDelay = getRateLimitResetDelayMs(error)
|
||||
delayMs =
|
||||
resetDelay ??
|
||||
Math.min(
|
||||
getRetryDelay(
|
||||
persistentAttempt,
|
||||
retryAfter,
|
||||
PERSISTENT_MAX_BACKOFF_MS,
|
||||
),
|
||||
PERSISTENT_RESET_CAP_MS,
|
||||
)
|
||||
} else if (persistent) {
|
||||
persistentAttempt++
|
||||
// Retry-After is a server directive and bypasses maxDelayMs inside
|
||||
// getRetryDelay (intentional — honoring it is correct). Cap at the
|
||||
// 6hr reset-cap here so a pathological header can't wait unbounded.
|
||||
delayMs = Math.min(
|
||||
getRetryDelay(
|
||||
persistentAttempt,
|
||||
retryAfter,
|
||||
PERSISTENT_MAX_BACKOFF_MS,
|
||||
),
|
||||
PERSISTENT_RESET_CAP_MS,
|
||||
)
|
||||
} else {
|
||||
delayMs = getRetryDelay(attempt, retryAfter)
|
||||
}
|
||||
|
||||
// In persistent mode the for-loop `attempt` is clamped at maxRetries+1;
|
||||
// use persistentAttempt for telemetry/yields so they show the true count.
|
||||
const reportedAttempt = persistent ? persistentAttempt : attempt
|
||||
logEvent('tengu_api_retry', {
|
||||
attempt: reportedAttempt,
|
||||
delayMs: delayMs,
|
||||
error: (error as APIError)
|
||||
.message as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
status: (error as APIError).status,
|
||||
provider: getAPIProviderForStatsig(),
|
||||
})
|
||||
|
||||
if (persistent) {
|
||||
if (delayMs > 60_000) {
|
||||
logEvent('tengu_api_persistent_retry_wait', {
|
||||
status: (error as APIError).status,
|
||||
delayMs,
|
||||
attempt: reportedAttempt,
|
||||
provider: getAPIProviderForStatsig(),
|
||||
})
|
||||
}
|
||||
// Chunk long sleeps so the host sees periodic stdout activity and
|
||||
// does not mark the session idle. Each yield surfaces as
|
||||
// {type:'system', subtype:'api_retry'} on stdout via QueryEngine.
|
||||
let remaining = delayMs
|
||||
while (remaining > 0) {
|
||||
if (options.signal?.aborted) throw new APIUserAbortError()
|
||||
if (error instanceof APIError) {
|
||||
yield createSystemAPIErrorMessage(
|
||||
error,
|
||||
remaining,
|
||||
reportedAttempt,
|
||||
maxRetries,
|
||||
)
|
||||
}
|
||||
const chunk = Math.min(remaining, HEARTBEAT_INTERVAL_MS)
|
||||
await sleep(chunk, options.signal, { abortError })
|
||||
remaining -= chunk
|
||||
}
|
||||
// Clamp so the for-loop never terminates. Backoff uses the separate
|
||||
// persistentAttempt counter which keeps growing to the 5-min cap.
|
||||
if (attempt >= maxRetries) attempt = maxRetries
|
||||
} else {
|
||||
if (error instanceof APIError) {
|
||||
yield createSystemAPIErrorMessage(error, delayMs, attempt, maxRetries)
|
||||
}
|
||||
await sleep(delayMs, options.signal, { abortError })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw new CannotRetryError(lastError, retryContext)
|
||||
}
|
||||
|
||||
function getRetryAfter(error: unknown): string | null {
|
||||
return (
|
||||
((error as { headers?: { 'retry-after'?: string } }).headers?.[
|
||||
'retry-after'
|
||||
] ||
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
((error as APIError).headers as Headers)?.get?.('retry-after')) ??
|
||||
null
|
||||
)
|
||||
}
|
||||
|
||||
export function getRetryDelay(
|
||||
attempt: number,
|
||||
retryAfterHeader?: string | null,
|
||||
maxDelayMs = 32000,
|
||||
): number {
|
||||
if (retryAfterHeader) {
|
||||
const seconds = parseInt(retryAfterHeader, 10)
|
||||
if (!isNaN(seconds)) {
|
||||
return seconds * 1000
|
||||
}
|
||||
}
|
||||
|
||||
const baseDelay = Math.min(
|
||||
BASE_DELAY_MS * Math.pow(2, attempt - 1),
|
||||
maxDelayMs,
|
||||
)
|
||||
const jitter = Math.random() * 0.25 * baseDelay
|
||||
return baseDelay + jitter
|
||||
}
|
||||
|
||||
export function parseMaxTokensContextOverflowError(error: APIError):
|
||||
| {
|
||||
inputTokens: number
|
||||
maxTokens: number
|
||||
contextLimit: number
|
||||
}
|
||||
| undefined {
|
||||
if (error.status !== 400 || !error.message) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (
|
||||
!error.message.includes(
|
||||
'input length and `max_tokens` exceed context limit',
|
||||
)
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Example format: "input length and `max_tokens` exceed context limit: 188059 + 20000 > 200000"
|
||||
const regex =
|
||||
/input length and `max_tokens` exceed context limit: (\d+) \+ (\d+) > (\d+)/
|
||||
const match = error.message.match(regex)
|
||||
|
||||
if (!match || match.length !== 4) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (!match[1] || !match[2] || !match[3]) {
|
||||
logError(
|
||||
new Error(
|
||||
'Unable to parse max_tokens from max_tokens exceed context limit error message',
|
||||
),
|
||||
)
|
||||
return undefined
|
||||
}
|
||||
const inputTokens = parseInt(match[1], 10)
|
||||
const maxTokens = parseInt(match[2], 10)
|
||||
const contextLimit = parseInt(match[3], 10)
|
||||
|
||||
if (isNaN(inputTokens) || isNaN(maxTokens) || isNaN(contextLimit)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return { inputTokens, maxTokens, contextLimit }
|
||||
}
|
||||
|
||||
// TODO: Replace with a response header check once the API adds a dedicated
|
||||
// header for fast-mode rejection (e.g., x-fast-mode-rejected). String-matching
|
||||
// the error message is fragile and will break if the API wording changes.
|
||||
function isFastModeNotEnabledError(error: unknown): boolean {
|
||||
if (!(error instanceof APIError)) {
|
||||
return false
|
||||
}
|
||||
return (
|
||||
error.status === 400 &&
|
||||
(error.message?.includes('Fast mode is not enabled') ?? false)
|
||||
)
|
||||
}
|
||||
|
||||
export function is529Error(error: unknown): boolean {
|
||||
if (!(error instanceof APIError)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for 529 status code or overloaded error in message
|
||||
return (
|
||||
error.status === 529 ||
|
||||
// See below: the SDK sometimes fails to properly pass the 529 status code during streaming
|
||||
(error.message?.includes('"type":"overloaded_error"') ?? false)
|
||||
)
|
||||
}
|
||||
|
||||
function isOAuthTokenRevokedError(error: unknown): boolean {
|
||||
return (
|
||||
error instanceof APIError &&
|
||||
error.status === 403 &&
|
||||
(error.message?.includes('OAuth token has been revoked') ?? false)
|
||||
)
|
||||
}
|
||||
|
||||
function isBedrockAuthError(error: unknown): boolean {
|
||||
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_BEDROCK)) {
|
||||
// AWS libs reject without an API call if .aws holds a past Expiration value
|
||||
// otherwise, API calls that receive expired tokens give generic 403
|
||||
// "The security token included in the request is invalid"
|
||||
if (
|
||||
isAwsCredentialsProviderError(error) ||
|
||||
(error instanceof APIError && error.status === 403)
|
||||
) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear AWS auth caches if appropriate.
|
||||
* @returns true if action was taken.
|
||||
*/
|
||||
function handleAwsCredentialError(error: unknown): boolean {
|
||||
if (isBedrockAuthError(error)) {
|
||||
clearAwsCredentialsCache()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// google-auth-library throws plain Error (no typed name like AWS's
|
||||
// CredentialsProviderError). Match common SDK-level credential-failure messages.
|
||||
function isGoogleAuthLibraryCredentialError(error: unknown): boolean {
|
||||
if (!(error instanceof Error)) return false
|
||||
const msg = error.message
|
||||
return (
|
||||
msg.includes('Could not load the default credentials') ||
|
||||
msg.includes('Could not refresh access token') ||
|
||||
msg.includes('invalid_grant')
|
||||
)
|
||||
}
|
||||
|
||||
function isVertexAuthError(error: unknown): boolean {
|
||||
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_VERTEX)) {
|
||||
// SDK-level: google-auth-library fails in prepareOptions() before the HTTP call
|
||||
if (isGoogleAuthLibraryCredentialError(error)) {
|
||||
return true
|
||||
}
|
||||
// Server-side: Vertex returns 401 for expired/invalid tokens
|
||||
if (error instanceof APIError && error.status === 401) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear GCP auth caches if appropriate.
|
||||
* @returns true if action was taken.
|
||||
*/
|
||||
function handleGcpCredentialError(error: unknown): boolean {
|
||||
if (isVertexAuthError(error)) {
|
||||
clearGcpCredentialsCache()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
function shouldRetry(error: APIError): boolean {
|
||||
// Never retry mock errors - they're from /mock-limits command for testing
|
||||
if (isMockRateLimitError(error)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Persistent mode: 429/529 always retryable, bypass subscriber gates and
|
||||
// x-should-retry header.
|
||||
if (isPersistentRetryEnabled() && isTransientCapacityError(error)) {
|
||||
return true
|
||||
}
|
||||
|
||||
// CCR mode: auth is via infrastructure-provided JWTs, so a 401/403 is a
|
||||
// transient blip (auth service flap, network hiccup) rather than bad
|
||||
// credentials. Bypass x-should-retry:false — the server assumes we'd retry
|
||||
// the same bad key, but our key is fine.
|
||||
if (
|
||||
isEnvTruthy(process.env.CLAUDE_CODE_REMOTE) &&
|
||||
(error.status === 401 || error.status === 403)
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for overloaded errors first by examining the message content
|
||||
// The SDK sometimes fails to properly pass the 529 status code during streaming,
|
||||
// so we need to check the error message directly
|
||||
if (error.message?.includes('"type":"overloaded_error"')) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for max tokens context overflow errors that we can handle
|
||||
if (parseMaxTokensContextOverflowError(error)) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Note this is not a standard header.
|
||||
const shouldRetryHeader = error.headers?.get('x-should-retry')
|
||||
|
||||
// If the server explicitly says whether or not to retry, obey.
|
||||
// For Max and Pro users, should-retry is true, but in several hours, so we shouldn't.
|
||||
// Enterprise users can retry because they typically use PAYG instead of rate limits.
|
||||
if (
|
||||
shouldRetryHeader === 'true' &&
|
||||
(!isClaudeAISubscriber() || isEnterpriseSubscriber())
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Ants can ignore x-should-retry: false for 5xx server errors only.
|
||||
// For other status codes (401, 403, 400, 429, etc.), respect the header.
|
||||
if (shouldRetryHeader === 'false') {
|
||||
const is5xxError = error.status !== undefined && error.status >= 500
|
||||
if (!(process.env.USER_TYPE === 'ant' && is5xxError)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (error instanceof APIConnectionError) {
|
||||
return true
|
||||
}
|
||||
|
||||
if (!error.status) return false
|
||||
|
||||
// Retry on request timeouts.
|
||||
if (error.status === 408) return true
|
||||
|
||||
// Retry on lock timeouts.
|
||||
if (error.status === 409) return true
|
||||
|
||||
// Retry on rate limits, but not for ClaudeAI Subscription users
|
||||
// Enterprise users can retry because they typically use PAYG instead of rate limits
|
||||
if (error.status === 429) {
|
||||
return !isClaudeAISubscriber() || isEnterpriseSubscriber()
|
||||
}
|
||||
|
||||
// Clear API key cache on 401 and allow retry.
|
||||
// OAuth token handling is done in the main retry loop via handleOAuth401Error.
|
||||
if (error.status === 401) {
|
||||
clearApiKeyHelperCache()
|
||||
return true
|
||||
}
|
||||
|
||||
// Retry on 403 "token revoked" (same refresh logic as 401, see above)
|
||||
if (isOAuthTokenRevokedError(error)) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Retry internal errors.
|
||||
if (error.status && error.status >= 500) return true
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
export function getDefaultMaxRetries(): number {
|
||||
if (process.env.CLAUDE_CODE_MAX_RETRIES) {
|
||||
return parseInt(process.env.CLAUDE_CODE_MAX_RETRIES, 10)
|
||||
}
|
||||
return DEFAULT_MAX_RETRIES
|
||||
}
|
||||
function getMaxRetries(options: RetryOptions): number {
|
||||
return options.maxRetries ?? getDefaultMaxRetries()
|
||||
}
|
||||
|
||||
const DEFAULT_FAST_MODE_FALLBACK_HOLD_MS = 30 * 60 * 1000 // 30 minutes
|
||||
const SHORT_RETRY_THRESHOLD_MS = 20 * 1000 // 20 seconds
|
||||
const MIN_COOLDOWN_MS = 10 * 60 * 1000 // 10 minutes
|
||||
|
||||
function getRetryAfterMs(error: APIError): number | null {
|
||||
const retryAfter = getRetryAfter(error)
|
||||
if (retryAfter) {
|
||||
const seconds = parseInt(retryAfter, 10)
|
||||
if (!isNaN(seconds)) {
|
||||
return seconds * 1000
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function getRateLimitResetDelayMs(error: APIError): number | null {
|
||||
const resetHeader = error.headers?.get?.('anthropic-ratelimit-unified-reset')
|
||||
if (!resetHeader) return null
|
||||
const resetUnixSec = Number(resetHeader)
|
||||
if (!Number.isFinite(resetUnixSec)) return null
|
||||
const delayMs = resetUnixSec * 1000 - Date.now()
|
||||
if (delayMs <= 0) return null
|
||||
return Math.min(delayMs, PERSISTENT_RESET_CAP_MS)
|
||||
}
|
||||
324
src/services/autoDream/autoDream.ts
Normal file
324
src/services/autoDream/autoDream.ts
Normal file
@@ -0,0 +1,324 @@
|
||||
// biome-ignore-all assist/source/organizeImports: ANT-ONLY import markers must not be reordered
|
||||
// Background memory consolidation. Fires the /dream prompt as a forked
|
||||
// subagent when time-gate passes AND enough sessions have accumulated.
|
||||
//
|
||||
// Gate order (cheapest first):
|
||||
// 1. Time: hours since lastConsolidatedAt >= minHours (one stat)
|
||||
// 2. Sessions: transcript count with mtime > lastConsolidatedAt >= minSessions
|
||||
// 3. Lock: no other process mid-consolidation
|
||||
//
|
||||
// State is closure-scoped inside initAutoDream() rather than module-level
|
||||
// (tests call initAutoDream() in beforeEach for a fresh closure).
|
||||
|
||||
import type { REPLHookContext } from '../../utils/hooks/postSamplingHooks.js'
|
||||
import {
|
||||
createCacheSafeParams,
|
||||
runForkedAgent,
|
||||
} from '../../utils/forkedAgent.js'
|
||||
import {
|
||||
createUserMessage,
|
||||
createMemorySavedMessage,
|
||||
} from '../../utils/messages.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import type { ToolUseContext } from '../../Tool.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
import { isAutoMemoryEnabled, getAutoMemPath } from '../../memdir/paths.js'
|
||||
import { isAutoDreamEnabled } from './config.js'
|
||||
import { getProjectDir } from '../../utils/sessionStorage.js'
|
||||
import {
|
||||
getOriginalCwd,
|
||||
getKairosActive,
|
||||
getIsRemoteMode,
|
||||
getSessionId,
|
||||
} from '../../bootstrap/state.js'
|
||||
import { createAutoMemCanUseTool } from '../extractMemories/extractMemories.js'
|
||||
import { buildConsolidationPrompt } from './consolidationPrompt.js'
|
||||
import {
|
||||
readLastConsolidatedAt,
|
||||
listSessionsTouchedSince,
|
||||
tryAcquireConsolidationLock,
|
||||
rollbackConsolidationLock,
|
||||
} from './consolidationLock.js'
|
||||
import {
|
||||
registerDreamTask,
|
||||
addDreamTurn,
|
||||
completeDreamTask,
|
||||
failDreamTask,
|
||||
isDreamTask,
|
||||
} from '../../tasks/DreamTask/DreamTask.js'
|
||||
import { FILE_EDIT_TOOL_NAME } from '../../tools/FileEditTool/constants.js'
|
||||
import { FILE_WRITE_TOOL_NAME } from '../../tools/FileWriteTool/prompt.js'
|
||||
|
||||
// Scan throttle: when time-gate passes but session-gate doesn't, the lock
|
||||
// mtime doesn't advance, so the time-gate keeps passing every turn.
|
||||
const SESSION_SCAN_INTERVAL_MS = 10 * 60 * 1000
|
||||
|
||||
type AutoDreamConfig = {
|
||||
minHours: number
|
||||
minSessions: number
|
||||
}
|
||||
|
||||
const DEFAULTS: AutoDreamConfig = {
|
||||
minHours: 24,
|
||||
minSessions: 5,
|
||||
}
|
||||
|
||||
/**
|
||||
* Thresholds from tengu_onyx_plover. The enabled gate lives in config.ts
|
||||
* (isAutoDreamEnabled); this returns only the scheduling knobs. Defensive
|
||||
* per-field validation since GB cache can return stale wrong-type values.
|
||||
*/
|
||||
function getConfig(): AutoDreamConfig {
|
||||
const raw =
|
||||
getFeatureValue_CACHED_MAY_BE_STALE<Partial<AutoDreamConfig> | null>(
|
||||
'tengu_onyx_plover',
|
||||
null,
|
||||
)
|
||||
return {
|
||||
minHours:
|
||||
typeof raw?.minHours === 'number' &&
|
||||
Number.isFinite(raw.minHours) &&
|
||||
raw.minHours > 0
|
||||
? raw.minHours
|
||||
: DEFAULTS.minHours,
|
||||
minSessions:
|
||||
typeof raw?.minSessions === 'number' &&
|
||||
Number.isFinite(raw.minSessions) &&
|
||||
raw.minSessions > 0
|
||||
? raw.minSessions
|
||||
: DEFAULTS.minSessions,
|
||||
}
|
||||
}
|
||||
|
||||
function isGateOpen(): boolean {
|
||||
if (getKairosActive()) return false // KAIROS mode uses disk-skill dream
|
||||
if (getIsRemoteMode()) return false
|
||||
if (!isAutoMemoryEnabled()) return false
|
||||
return isAutoDreamEnabled()
|
||||
}
|
||||
|
||||
// Ant-build-only test override. Bypasses enabled/time/session gates but NOT
|
||||
// the lock (so repeated turns don't pile up dreams) or the memory-dir
|
||||
// precondition. Still scans sessions so the prompt's session-hint is populated.
|
||||
function isForced(): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
type AppendSystemMessageFn = NonNullable<ToolUseContext['appendSystemMessage']>
|
||||
|
||||
let runner:
|
||||
| ((
|
||||
context: REPLHookContext,
|
||||
appendSystemMessage?: AppendSystemMessageFn,
|
||||
) => Promise<void>)
|
||||
| null = null
|
||||
|
||||
/**
|
||||
* Call once at startup (from backgroundHousekeeping alongside
|
||||
* initExtractMemories), or per-test in beforeEach for a fresh closure.
|
||||
*/
|
||||
export function initAutoDream(): void {
|
||||
let lastSessionScanAt = 0
|
||||
|
||||
runner = async function runAutoDream(context, appendSystemMessage) {
|
||||
const cfg = getConfig()
|
||||
const force = isForced()
|
||||
if (!force && !isGateOpen()) return
|
||||
|
||||
// --- Time gate ---
|
||||
let lastAt: number
|
||||
try {
|
||||
lastAt = await readLastConsolidatedAt()
|
||||
} catch (e: unknown) {
|
||||
logForDebugging(
|
||||
`[autoDream] readLastConsolidatedAt failed: ${(e as Error).message}`,
|
||||
)
|
||||
return
|
||||
}
|
||||
const hoursSince = (Date.now() - lastAt) / 3_600_000
|
||||
if (!force && hoursSince < cfg.minHours) return
|
||||
|
||||
// --- Scan throttle ---
|
||||
const sinceScanMs = Date.now() - lastSessionScanAt
|
||||
if (!force && sinceScanMs < SESSION_SCAN_INTERVAL_MS) {
|
||||
logForDebugging(
|
||||
`[autoDream] scan throttle — time-gate passed but last scan was ${Math.round(sinceScanMs / 1000)}s ago`,
|
||||
)
|
||||
return
|
||||
}
|
||||
lastSessionScanAt = Date.now()
|
||||
|
||||
// --- Session gate ---
|
||||
let sessionIds: string[]
|
||||
try {
|
||||
sessionIds = await listSessionsTouchedSince(lastAt)
|
||||
} catch (e: unknown) {
|
||||
logForDebugging(
|
||||
`[autoDream] listSessionsTouchedSince failed: ${(e as Error).message}`,
|
||||
)
|
||||
return
|
||||
}
|
||||
// Exclude the current session (its mtime is always recent).
|
||||
const currentSession = getSessionId()
|
||||
sessionIds = sessionIds.filter(id => id !== currentSession)
|
||||
if (!force && sessionIds.length < cfg.minSessions) {
|
||||
logForDebugging(
|
||||
`[autoDream] skip — ${sessionIds.length} sessions since last consolidation, need ${cfg.minSessions}`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// --- Lock ---
|
||||
// Under force, skip acquire entirely — use the existing mtime so
|
||||
// kill's rollback is a no-op (rewinds to where it already is).
|
||||
// The lock file stays untouched; next non-force turn sees it as-is.
|
||||
let priorMtime: number | null
|
||||
if (force) {
|
||||
priorMtime = lastAt
|
||||
} else {
|
||||
try {
|
||||
priorMtime = await tryAcquireConsolidationLock()
|
||||
} catch (e: unknown) {
|
||||
logForDebugging(
|
||||
`[autoDream] lock acquire failed: ${(e as Error).message}`,
|
||||
)
|
||||
return
|
||||
}
|
||||
if (priorMtime === null) return
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[autoDream] firing — ${hoursSince.toFixed(1)}h since last, ${sessionIds.length} sessions to review`,
|
||||
)
|
||||
logEvent('tengu_auto_dream_fired', {
|
||||
hours_since: Math.round(hoursSince),
|
||||
sessions_since: sessionIds.length,
|
||||
})
|
||||
|
||||
const setAppState =
|
||||
context.toolUseContext.setAppStateForTasks ??
|
||||
context.toolUseContext.setAppState
|
||||
const abortController = new AbortController()
|
||||
const taskId = registerDreamTask(setAppState, {
|
||||
sessionsReviewing: sessionIds.length,
|
||||
priorMtime,
|
||||
abortController,
|
||||
})
|
||||
|
||||
try {
|
||||
const memoryRoot = getAutoMemPath()
|
||||
const transcriptDir = getProjectDir(getOriginalCwd())
|
||||
// Tool constraints note goes in `extra`, not the shared prompt body —
|
||||
// manual /dream runs in the main loop with normal permissions and this
|
||||
// would be misleading there.
|
||||
const extra = `
|
||||
|
||||
**Tool constraints for this run:** Bash is restricted to read-only commands (\`ls\`, \`find\`, \`grep\`, \`cat\`, \`stat\`, \`wc\`, \`head\`, \`tail\`, and similar). Anything that writes, redirects to a file, or modifies state will be denied. Plan your exploration with this in mind — no need to probe.
|
||||
|
||||
Sessions since last consolidation (${sessionIds.length}):
|
||||
${sessionIds.map(id => `- ${id}`).join('\n')}`
|
||||
const prompt = buildConsolidationPrompt(memoryRoot, transcriptDir, extra)
|
||||
|
||||
const result = await runForkedAgent({
|
||||
promptMessages: [createUserMessage({ content: prompt })],
|
||||
cacheSafeParams: createCacheSafeParams(context),
|
||||
canUseTool: createAutoMemCanUseTool(memoryRoot),
|
||||
querySource: 'auto_dream',
|
||||
forkLabel: 'auto_dream',
|
||||
skipTranscript: true,
|
||||
overrides: { abortController },
|
||||
onMessage: makeDreamProgressWatcher(taskId, setAppState),
|
||||
})
|
||||
|
||||
completeDreamTask(taskId, setAppState)
|
||||
// Inline completion summary in the main transcript (same surface as
|
||||
// extractMemories's "Saved N memories" message).
|
||||
const dreamState = context.toolUseContext.getAppState().tasks?.[taskId]
|
||||
if (
|
||||
appendSystemMessage &&
|
||||
isDreamTask(dreamState) &&
|
||||
dreamState.filesTouched.length > 0
|
||||
) {
|
||||
appendSystemMessage({
|
||||
...createMemorySavedMessage(dreamState.filesTouched),
|
||||
verb: 'Improved',
|
||||
})
|
||||
}
|
||||
logForDebugging(
|
||||
`[autoDream] completed — cache: read=${result.totalUsage.cache_read_input_tokens} created=${result.totalUsage.cache_creation_input_tokens}`,
|
||||
)
|
||||
logEvent('tengu_auto_dream_completed', {
|
||||
cache_read: result.totalUsage.cache_read_input_tokens,
|
||||
cache_created: result.totalUsage.cache_creation_input_tokens,
|
||||
output: result.totalUsage.output_tokens,
|
||||
sessions_reviewed: sessionIds.length,
|
||||
})
|
||||
} catch (e: unknown) {
|
||||
// If the user killed from the bg-tasks dialog, DreamTask.kill already
|
||||
// aborted, rolled back the lock, and set status=killed. Don't overwrite
|
||||
// or double-rollback.
|
||||
if (abortController.signal.aborted) {
|
||||
logForDebugging('[autoDream] aborted by user')
|
||||
return
|
||||
}
|
||||
logForDebugging(`[autoDream] fork failed: ${(e as Error).message}`)
|
||||
logEvent('tengu_auto_dream_failed', {})
|
||||
failDreamTask(taskId, setAppState)
|
||||
// Rewind mtime so time-gate passes again. Scan throttle is the backoff.
|
||||
await rollbackConsolidationLock(priorMtime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Watch the forked agent's messages. For each assistant turn, extracts any
|
||||
* text blocks (the agent's reasoning/summary — what the user wants to see)
|
||||
* and collapses tool_use blocks to a count. Edit/Write file_paths are
|
||||
* collected for phase-flip + the inline completion message.
|
||||
*/
|
||||
function makeDreamProgressWatcher(
|
||||
taskId: string,
|
||||
setAppState: import('../../Task.js').SetAppState,
|
||||
): (msg: Message) => void {
|
||||
return msg => {
|
||||
if (msg.type !== 'assistant') return
|
||||
let text = ''
|
||||
let toolUseCount = 0
|
||||
const touchedPaths: string[] = []
|
||||
for (const block of msg.message.content) {
|
||||
if (block.type === 'text') {
|
||||
text += block.text
|
||||
} else if (block.type === 'tool_use') {
|
||||
toolUseCount++
|
||||
if (
|
||||
block.name === FILE_EDIT_TOOL_NAME ||
|
||||
block.name === FILE_WRITE_TOOL_NAME
|
||||
) {
|
||||
const input = block.input as { file_path?: unknown }
|
||||
if (typeof input.file_path === 'string') {
|
||||
touchedPaths.push(input.file_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
addDreamTurn(
|
||||
taskId,
|
||||
{ text: text.trim(), toolUseCount },
|
||||
touchedPaths,
|
||||
setAppState,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Entry point from stopHooks. No-op until initAutoDream() has been called.
|
||||
* Per-turn cost when enabled: one GB cache read + one stat.
|
||||
*/
|
||||
export async function executeAutoDream(
|
||||
context: REPLHookContext,
|
||||
appendSystemMessage?: AppendSystemMessageFn,
|
||||
): Promise<void> {
|
||||
await runner?.(context, appendSystemMessage)
|
||||
}
|
||||
21
src/services/autoDream/config.ts
Normal file
21
src/services/autoDream/config.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
// Leaf config module — intentionally minimal imports so UI components
|
||||
// can read the auto-dream enabled state without dragging in the forked
|
||||
// agent / task registry / message builder chain that autoDream.ts pulls in.
|
||||
|
||||
import { getInitialSettings } from '../../utils/settings/settings.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
|
||||
/**
|
||||
* Whether background memory consolidation should run. User setting
|
||||
* (autoDreamEnabled in settings.json) overrides the GrowthBook default
|
||||
* when explicitly set; otherwise falls through to tengu_onyx_plover.
|
||||
*/
|
||||
export function isAutoDreamEnabled(): boolean {
|
||||
const setting = getInitialSettings().autoDreamEnabled
|
||||
if (setting !== undefined) return setting
|
||||
const gb = getFeatureValue_CACHED_MAY_BE_STALE<{ enabled?: unknown } | null>(
|
||||
'tengu_onyx_plover',
|
||||
null,
|
||||
)
|
||||
return gb?.enabled === true
|
||||
}
|
||||
140
src/services/autoDream/consolidationLock.ts
Normal file
140
src/services/autoDream/consolidationLock.ts
Normal file
@@ -0,0 +1,140 @@
|
||||
// Lock file whose mtime IS lastConsolidatedAt. Body is the holder's PID.
|
||||
//
|
||||
// Lives inside the memory dir (getAutoMemPath) so it keys on git-root
|
||||
// like memory does, and so it's writable even when the memory path comes
|
||||
// from an env/settings override whose parent may not be.
|
||||
|
||||
import { mkdir, readFile, stat, unlink, utimes, writeFile } from 'fs/promises'
|
||||
import { join } from 'path'
|
||||
import { getOriginalCwd } from '../../bootstrap/state.js'
|
||||
import { getAutoMemPath } from '../../memdir/paths.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { isProcessRunning } from '../../utils/genericProcessUtils.js'
|
||||
import { listCandidates } from '../../utils/listSessionsImpl.js'
|
||||
import { getProjectDir } from '../../utils/sessionStorage.js'
|
||||
|
||||
const LOCK_FILE = '.consolidate-lock'
|
||||
|
||||
// Stale past this even if the PID is live (PID reuse guard).
|
||||
const HOLDER_STALE_MS = 60 * 60 * 1000
|
||||
|
||||
function lockPath(): string {
|
||||
return join(getAutoMemPath(), LOCK_FILE)
|
||||
}
|
||||
|
||||
/**
|
||||
* mtime of the lock file = lastConsolidatedAt. 0 if absent.
|
||||
* Per-turn cost: one stat.
|
||||
*/
|
||||
export async function readLastConsolidatedAt(): Promise<number> {
|
||||
try {
|
||||
const s = await stat(lockPath())
|
||||
return s.mtimeMs
|
||||
} catch {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Acquire: write PID → mtime = now. Returns the pre-acquire mtime
|
||||
* (for rollback), or null if blocked / lost a race.
|
||||
*
|
||||
* Success → do nothing. mtime stays at now.
|
||||
* Failure → rollbackConsolidationLock(priorMtime) rewinds mtime.
|
||||
* Crash → mtime stuck, dead PID → next process reclaims.
|
||||
*/
|
||||
export async function tryAcquireConsolidationLock(): Promise<number | null> {
|
||||
const path = lockPath()
|
||||
|
||||
let mtimeMs: number | undefined
|
||||
let holderPid: number | undefined
|
||||
try {
|
||||
const [s, raw] = await Promise.all([stat(path), readFile(path, 'utf8')])
|
||||
mtimeMs = s.mtimeMs
|
||||
const parsed = parseInt(raw.trim(), 10)
|
||||
holderPid = Number.isFinite(parsed) ? parsed : undefined
|
||||
} catch {
|
||||
// ENOENT — no prior lock.
|
||||
}
|
||||
|
||||
if (mtimeMs !== undefined && Date.now() - mtimeMs < HOLDER_STALE_MS) {
|
||||
if (holderPid !== undefined && isProcessRunning(holderPid)) {
|
||||
logForDebugging(
|
||||
`[autoDream] lock held by live PID ${holderPid} (mtime ${Math.round((Date.now() - mtimeMs) / 1000)}s ago)`,
|
||||
)
|
||||
return null
|
||||
}
|
||||
// Dead PID or unparseable body — reclaim.
|
||||
}
|
||||
|
||||
// Memory dir may not exist yet.
|
||||
await mkdir(getAutoMemPath(), { recursive: true })
|
||||
await writeFile(path, String(process.pid))
|
||||
|
||||
// Two reclaimers both write → last wins the PID. Loser bails on re-read.
|
||||
let verify: string
|
||||
try {
|
||||
verify = await readFile(path, 'utf8')
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
if (parseInt(verify.trim(), 10) !== process.pid) return null
|
||||
|
||||
return mtimeMs ?? 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Rewind mtime to pre-acquire after a failed fork. Clears the PID body —
|
||||
* otherwise our still-running process would look like it's holding.
|
||||
* priorMtime 0 → unlink (restore no-file).
|
||||
*/
|
||||
export async function rollbackConsolidationLock(
|
||||
priorMtime: number,
|
||||
): Promise<void> {
|
||||
const path = lockPath()
|
||||
try {
|
||||
if (priorMtime === 0) {
|
||||
await unlink(path)
|
||||
return
|
||||
}
|
||||
await writeFile(path, '')
|
||||
const t = priorMtime / 1000 // utimes wants seconds
|
||||
await utimes(path, t, t)
|
||||
} catch (e: unknown) {
|
||||
logForDebugging(
|
||||
`[autoDream] rollback failed: ${(e as Error).message} — next trigger delayed to minHours`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Session IDs with mtime after sinceMs. listCandidates handles UUID
|
||||
* validation (excludes agent-*.jsonl) and parallel stat.
|
||||
*
|
||||
* Uses mtime (sessions TOUCHED since), not birthtime (0 on ext4).
|
||||
* Caller excludes the current session. Scans per-cwd transcripts — it's
|
||||
* a skip-gate, so undercounting worktree sessions is safe.
|
||||
*/
|
||||
export async function listSessionsTouchedSince(
|
||||
sinceMs: number,
|
||||
): Promise<string[]> {
|
||||
const dir = getProjectDir(getOriginalCwd())
|
||||
const candidates = await listCandidates(dir, true)
|
||||
return candidates.filter(c => c.mtime > sinceMs).map(c => c.sessionId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Stamp from manual /dream. Optimistic — fires at prompt-build time,
|
||||
* no post-skill completion hook. Best-effort.
|
||||
*/
|
||||
export async function recordConsolidation(): Promise<void> {
|
||||
try {
|
||||
// Memory dir may not exist yet (manual /dream before any auto-trigger).
|
||||
await mkdir(getAutoMemPath(), { recursive: true })
|
||||
await writeFile(lockPath(), String(process.pid))
|
||||
} catch (e: unknown) {
|
||||
logForDebugging(
|
||||
`[autoDream] recordConsolidation write failed: ${(e as Error).message}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
65
src/services/autoDream/consolidationPrompt.ts
Normal file
65
src/services/autoDream/consolidationPrompt.ts
Normal file
@@ -0,0 +1,65 @@
|
||||
// Extracted from dream.ts so auto-dream ships independently of KAIROS
|
||||
// feature flags (dream.ts is behind a feature()-gated require).
|
||||
|
||||
import {
|
||||
DIR_EXISTS_GUIDANCE,
|
||||
ENTRYPOINT_NAME,
|
||||
MAX_ENTRYPOINT_LINES,
|
||||
} from '../../memdir/memdir.js'
|
||||
|
||||
export function buildConsolidationPrompt(
|
||||
memoryRoot: string,
|
||||
transcriptDir: string,
|
||||
extra: string,
|
||||
): string {
|
||||
return `# Dream: Memory Consolidation
|
||||
|
||||
You are performing a dream — a reflective pass over your memory files. Synthesize what you've learned recently into durable, well-organized memories so that future sessions can orient quickly.
|
||||
|
||||
Memory directory: \`${memoryRoot}\`
|
||||
${DIR_EXISTS_GUIDANCE}
|
||||
|
||||
Session transcripts: \`${transcriptDir}\` (large JSONL files — grep narrowly, don't read whole files)
|
||||
|
||||
---
|
||||
|
||||
## Phase 1 — Orient
|
||||
|
||||
- \`ls\` the memory directory to see what already exists
|
||||
- Read \`${ENTRYPOINT_NAME}\` to understand the current index
|
||||
- Skim existing topic files so you improve them rather than creating duplicates
|
||||
- If \`logs/\` or \`sessions/\` subdirectories exist (assistant-mode layout), review recent entries there
|
||||
|
||||
## Phase 2 — Gather recent signal
|
||||
|
||||
Look for new information worth persisting. Sources in rough priority order:
|
||||
|
||||
1. **Daily logs** (\`logs/YYYY/MM/YYYY-MM-DD.md\`) if present — these are the append-only stream
|
||||
2. **Existing memories that drifted** — facts that contradict something you see in the codebase now
|
||||
3. **Transcript search** — if you need specific context (e.g., "what was the error message from yesterday's build failure?"), grep the JSONL transcripts for narrow terms:
|
||||
\`grep -rn "<narrow term>" ${transcriptDir}/ --include="*.jsonl" | tail -50\`
|
||||
|
||||
Don't exhaustively read transcripts. Look only for things you already suspect matter.
|
||||
|
||||
## Phase 3 — Consolidate
|
||||
|
||||
For each thing worth remembering, write or update a memory file at the top level of the memory directory. Use the memory file format and type conventions from your system prompt's auto-memory section — it's the source of truth for what to save, how to structure it, and what NOT to save.
|
||||
|
||||
Focus on:
|
||||
- Merging new signal into existing topic files rather than creating near-duplicates
|
||||
- Converting relative dates ("yesterday", "last week") to absolute dates so they remain interpretable after time passes
|
||||
- Deleting contradicted facts — if today's investigation disproves an old memory, fix it at the source
|
||||
|
||||
## Phase 4 — Prune and index
|
||||
|
||||
Update \`${ENTRYPOINT_NAME}\` so it stays under ${MAX_ENTRYPOINT_LINES} lines AND under ~25KB. It's an **index**, not a dump — each entry should be one line under ~150 characters: \`- [Title](file.md) — one-line hook\`. Never write memory content directly into it.
|
||||
|
||||
- Remove pointers to memories that are now stale, wrong, or superseded
|
||||
- Demote verbose entries: if an index line is over ~200 chars, it's carrying content that belongs in the topic file — shorten the line, move the detail
|
||||
- Add pointers to newly important memories
|
||||
- Resolve contradictions — if two files disagree, fix the wrong one
|
||||
|
||||
---
|
||||
|
||||
Return a brief summary of what you consolidated, updated, or pruned. If nothing changed (memories are already tight), say so.${extra ? `\n\n## Additional context\n\n${extra}` : ''}`
|
||||
}
|
||||
74
src/services/awaySummary.ts
Normal file
74
src/services/awaySummary.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
import { APIUserAbortError } from '@anthropic-ai/sdk'
|
||||
import { getEmptyToolPermissionContext } from '../Tool.js'
|
||||
import type { Message } from '../types/message.js'
|
||||
import { logForDebugging } from '../utils/debug.js'
|
||||
import {
|
||||
createUserMessage,
|
||||
getAssistantMessageText,
|
||||
} from '../utils/messages.js'
|
||||
import { getSmallFastModel } from '../utils/model/model.js'
|
||||
import { asSystemPrompt } from '../utils/systemPromptType.js'
|
||||
import { queryModelWithoutStreaming } from './api/claude.js'
|
||||
import { getSessionMemoryContent } from './SessionMemory/sessionMemoryUtils.js'
|
||||
|
||||
// Recap only needs recent context — truncate to avoid "prompt too long" on
|
||||
// large sessions. 30 messages ≈ ~15 exchanges, plenty for "where we left off."
|
||||
const RECENT_MESSAGE_WINDOW = 30
|
||||
|
||||
function buildAwaySummaryPrompt(memory: string | null): string {
|
||||
const memoryBlock = memory
|
||||
? `Session memory (broader context):\n${memory}\n\n`
|
||||
: ''
|
||||
return `${memoryBlock}The user stepped away and is coming back. Write exactly 1-3 short sentences. Start by stating the high-level task — what they are building or debugging, not implementation details. Next: the concrete next step. Skip status reports and commit recaps.`
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a short session recap for the "while you were away" card.
|
||||
* Returns null on abort, empty transcript, or error.
|
||||
*/
|
||||
export async function generateAwaySummary(
|
||||
messages: readonly Message[],
|
||||
signal: AbortSignal,
|
||||
): Promise<string | null> {
|
||||
if (messages.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const memory = await getSessionMemoryContent()
|
||||
const recent = messages.slice(-RECENT_MESSAGE_WINDOW)
|
||||
recent.push(createUserMessage({ content: buildAwaySummaryPrompt(memory) }))
|
||||
const response = await queryModelWithoutStreaming({
|
||||
messages: recent,
|
||||
systemPrompt: asSystemPrompt([]),
|
||||
thinkingConfig: { type: 'disabled' },
|
||||
tools: [],
|
||||
signal,
|
||||
options: {
|
||||
getToolPermissionContext: async () => getEmptyToolPermissionContext(),
|
||||
model: getSmallFastModel(),
|
||||
toolChoice: undefined,
|
||||
isNonInteractiveSession: false,
|
||||
hasAppendSystemPrompt: false,
|
||||
agents: [],
|
||||
querySource: 'away_summary',
|
||||
mcpTools: [],
|
||||
skipCacheWrite: true,
|
||||
},
|
||||
})
|
||||
|
||||
if (response.isApiErrorMessage) {
|
||||
logForDebugging(
|
||||
`[awaySummary] API error: ${getAssistantMessageText(response)}`,
|
||||
)
|
||||
return null
|
||||
}
|
||||
return getAssistantMessageText(response)
|
||||
} catch (err) {
|
||||
if (err instanceof APIUserAbortError || signal.aborted) {
|
||||
return null
|
||||
}
|
||||
logForDebugging(`[awaySummary] generation failed: ${err}`)
|
||||
return null
|
||||
}
|
||||
}
|
||||
515
src/services/claudeAiLimits.ts
Normal file
515
src/services/claudeAiLimits.ts
Normal file
@@ -0,0 +1,515 @@
|
||||
import { APIError } from '@anthropic-ai/sdk'
|
||||
import type { MessageParam } from '@anthropic-ai/sdk/resources/index.mjs'
|
||||
import isEqual from 'lodash-es/isEqual.js'
|
||||
import { getIsNonInteractiveSession } from '../bootstrap/state.js'
|
||||
import { isClaudeAISubscriber } from '../utils/auth.js'
|
||||
import { getModelBetas } from '../utils/betas.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../utils/config.js'
|
||||
import { logError } from '../utils/log.js'
|
||||
import { getSmallFastModel } from '../utils/model/model.js'
|
||||
import { isEssentialTrafficOnly } from '../utils/privacyLevel.js'
|
||||
import type { AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS } from './analytics/index.js'
|
||||
import { logEvent } from './analytics/index.js'
|
||||
import { getAPIMetadata } from './api/claude.js'
|
||||
import { getAnthropicClient } from './api/client.js'
|
||||
import {
|
||||
processRateLimitHeaders,
|
||||
shouldProcessRateLimits,
|
||||
} from './rateLimitMocking.js'
|
||||
|
||||
// Re-export message functions from centralized location
|
||||
export {
|
||||
getRateLimitErrorMessage,
|
||||
getRateLimitWarning,
|
||||
getUsingOverageText,
|
||||
} from './rateLimitMessages.js'
|
||||
|
||||
type QuotaStatus = 'allowed' | 'allowed_warning' | 'rejected'
|
||||
|
||||
type RateLimitType =
|
||||
| 'five_hour'
|
||||
| 'seven_day'
|
||||
| 'seven_day_opus'
|
||||
| 'seven_day_sonnet'
|
||||
| 'overage'
|
||||
|
||||
export type { RateLimitType }
|
||||
|
||||
type EarlyWarningThreshold = {
|
||||
utilization: number // 0-1 scale: trigger warning when usage >= this
|
||||
timePct: number // 0-1 scale: trigger warning when time elapsed <= this
|
||||
}
|
||||
|
||||
type EarlyWarningConfig = {
|
||||
rateLimitType: RateLimitType
|
||||
claimAbbrev: '5h' | '7d'
|
||||
windowSeconds: number
|
||||
thresholds: EarlyWarningThreshold[]
|
||||
}
|
||||
|
||||
// Early warning configurations in priority order (checked first to last)
|
||||
// Used as fallback when server doesn't send surpassed-threshold header
|
||||
// Warns users when they're consuming quota faster than the time window allows
|
||||
const EARLY_WARNING_CONFIGS: EarlyWarningConfig[] = [
|
||||
{
|
||||
rateLimitType: 'five_hour',
|
||||
claimAbbrev: '5h',
|
||||
windowSeconds: 5 * 60 * 60,
|
||||
thresholds: [{ utilization: 0.9, timePct: 0.72 }],
|
||||
},
|
||||
{
|
||||
rateLimitType: 'seven_day',
|
||||
claimAbbrev: '7d',
|
||||
windowSeconds: 7 * 24 * 60 * 60,
|
||||
thresholds: [
|
||||
{ utilization: 0.75, timePct: 0.6 },
|
||||
{ utilization: 0.5, timePct: 0.35 },
|
||||
{ utilization: 0.25, timePct: 0.15 },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
// Maps claim abbreviations to rate limit types for header-based detection
|
||||
const EARLY_WARNING_CLAIM_MAP: Record<string, RateLimitType> = {
|
||||
'5h': 'five_hour',
|
||||
'7d': 'seven_day',
|
||||
overage: 'overage',
|
||||
}
|
||||
|
||||
const RATE_LIMIT_DISPLAY_NAMES: Record<RateLimitType, string> = {
|
||||
five_hour: 'session limit',
|
||||
seven_day: 'weekly limit',
|
||||
seven_day_opus: 'Opus limit',
|
||||
seven_day_sonnet: 'Sonnet limit',
|
||||
overage: 'extra usage limit',
|
||||
}
|
||||
|
||||
export function getRateLimitDisplayName(type: RateLimitType): string {
|
||||
return RATE_LIMIT_DISPLAY_NAMES[type] || type
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate what fraction of a time window has elapsed.
|
||||
* Used for time-relative early warning fallback.
|
||||
* @param resetsAt - Unix epoch timestamp in seconds when the limit resets
|
||||
* @param windowSeconds - Duration of the window in seconds
|
||||
* @returns fraction (0-1) of the window that has elapsed
|
||||
*/
|
||||
function computeTimeProgress(resetsAt: number, windowSeconds: number): number {
|
||||
const nowSeconds = Date.now() / 1000
|
||||
const windowStart = resetsAt - windowSeconds
|
||||
const elapsed = nowSeconds - windowStart
|
||||
return Math.max(0, Math.min(1, elapsed / windowSeconds))
|
||||
}
|
||||
|
||||
// Reason why overage is disabled/rejected
|
||||
// These values come from the API's unified limiter
|
||||
export type OverageDisabledReason =
|
||||
| 'overage_not_provisioned' // Overage is not provisioned for this org or seat tier
|
||||
| 'org_level_disabled' // Organization doesn't have overage enabled
|
||||
| 'org_level_disabled_until' // Organization overage temporarily disabled
|
||||
| 'out_of_credits' // Organization has insufficient credits
|
||||
| 'seat_tier_level_disabled' // Seat tier doesn't have overage enabled
|
||||
| 'member_level_disabled' // Account specifically has overage disabled
|
||||
| 'seat_tier_zero_credit_limit' // Seat tier has a zero credit limit
|
||||
| 'group_zero_credit_limit' // Resolved group limit has a zero credit limit
|
||||
| 'member_zero_credit_limit' // Account has a zero credit limit
|
||||
| 'org_service_level_disabled' // Org service specifically has overage disabled
|
||||
| 'org_service_zero_credit_limit' // Org service has a zero credit limit
|
||||
| 'no_limits_configured' // No overage limits configured for account
|
||||
| 'unknown' // Unknown reason, should not happen
|
||||
|
||||
export type ClaudeAILimits = {
|
||||
status: QuotaStatus
|
||||
// unifiedRateLimitFallbackAvailable is currently used to warn users that set
|
||||
// their model to Opus whenever they are about to run out of quota. It does
|
||||
// not change the actual model that is used.
|
||||
unifiedRateLimitFallbackAvailable: boolean
|
||||
resetsAt?: number
|
||||
rateLimitType?: RateLimitType
|
||||
utilization?: number
|
||||
overageStatus?: QuotaStatus
|
||||
overageResetsAt?: number
|
||||
overageDisabledReason?: OverageDisabledReason
|
||||
isUsingOverage?: boolean
|
||||
surpassedThreshold?: number
|
||||
}
|
||||
|
||||
// Exported for testing only
|
||||
export let currentLimits: ClaudeAILimits = {
|
||||
status: 'allowed',
|
||||
unifiedRateLimitFallbackAvailable: false,
|
||||
isUsingOverage: false,
|
||||
}
|
||||
|
||||
/**
|
||||
* Raw per-window utilization from response headers, tracked on every API
|
||||
* response (unlike currentLimits.utilization which is only set when a warning
|
||||
* threshold fires). Exposed to statusline scripts via getRawUtilization().
|
||||
*/
|
||||
type RawWindowUtilization = {
|
||||
utilization: number // 0-1 fraction
|
||||
resets_at: number // unix epoch seconds
|
||||
}
|
||||
type RawUtilization = {
|
||||
five_hour?: RawWindowUtilization
|
||||
seven_day?: RawWindowUtilization
|
||||
}
|
||||
let rawUtilization: RawUtilization = {}
|
||||
|
||||
export function getRawUtilization(): RawUtilization {
|
||||
return rawUtilization
|
||||
}
|
||||
|
||||
function extractRawUtilization(headers: globalThis.Headers): RawUtilization {
|
||||
const result: RawUtilization = {}
|
||||
for (const [key, abbrev] of [
|
||||
['five_hour', '5h'],
|
||||
['seven_day', '7d'],
|
||||
] as const) {
|
||||
const util = headers.get(
|
||||
`anthropic-ratelimit-unified-${abbrev}-utilization`,
|
||||
)
|
||||
const reset = headers.get(`anthropic-ratelimit-unified-${abbrev}-reset`)
|
||||
if (util !== null && reset !== null) {
|
||||
result[key] = { utilization: Number(util), resets_at: Number(reset) }
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type StatusChangeListener = (limits: ClaudeAILimits) => void
|
||||
export const statusListeners: Set<StatusChangeListener> = new Set()
|
||||
|
||||
export function emitStatusChange(limits: ClaudeAILimits) {
|
||||
currentLimits = limits
|
||||
statusListeners.forEach(listener => listener(limits))
|
||||
const hoursTillReset = Math.round(
|
||||
(limits.resetsAt ? limits.resetsAt - Date.now() / 1000 : 0) / (60 * 60),
|
||||
)
|
||||
|
||||
logEvent('tengu_claudeai_limits_status_changed', {
|
||||
status:
|
||||
limits.status as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
unifiedRateLimitFallbackAvailable: limits.unifiedRateLimitFallbackAvailable,
|
||||
hoursTillReset,
|
||||
})
|
||||
}
|
||||
|
||||
async function makeTestQuery() {
|
||||
const model = getSmallFastModel()
|
||||
const anthropic = await getAnthropicClient({
|
||||
maxRetries: 0,
|
||||
model,
|
||||
source: 'quota_check',
|
||||
})
|
||||
const messages: MessageParam[] = [{ role: 'user', content: 'quota' }]
|
||||
const betas = getModelBetas(model)
|
||||
// biome-ignore lint/plugin: quota check needs raw response access via asResponse()
|
||||
return anthropic.beta.messages
|
||||
.create({
|
||||
model,
|
||||
max_tokens: 1,
|
||||
messages,
|
||||
metadata: getAPIMetadata(),
|
||||
...(betas.length > 0 ? { betas } : {}),
|
||||
})
|
||||
.asResponse()
|
||||
}
|
||||
|
||||
export async function checkQuotaStatus(): Promise<void> {
|
||||
// Skip network requests if nonessential traffic is disabled
|
||||
if (isEssentialTrafficOnly()) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we should process rate limits (real subscriber or mock testing)
|
||||
if (!shouldProcessRateLimits(isClaudeAISubscriber())) {
|
||||
return
|
||||
}
|
||||
|
||||
// In non-interactive mode (-p), the real query follows immediately and
|
||||
// extractQuotaStatusFromHeaders() will update limits from its response
|
||||
// headers (claude.ts), so skip this pre-check API call.
|
||||
if (getIsNonInteractiveSession()) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
// Make a minimal request to check quota
|
||||
const raw = await makeTestQuery()
|
||||
|
||||
// Update limits based on the response
|
||||
extractQuotaStatusFromHeaders(raw.headers)
|
||||
} catch (error) {
|
||||
if (error instanceof APIError) {
|
||||
extractQuotaStatusFromError(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if early warning should be triggered based on surpassed-threshold header.
|
||||
* Returns ClaudeAILimits if a threshold was surpassed, null otherwise.
|
||||
*/
|
||||
function getHeaderBasedEarlyWarning(
|
||||
headers: globalThis.Headers,
|
||||
unifiedRateLimitFallbackAvailable: boolean,
|
||||
): ClaudeAILimits | null {
|
||||
// Check each claim type for surpassed threshold header
|
||||
for (const [claimAbbrev, rateLimitType] of Object.entries(
|
||||
EARLY_WARNING_CLAIM_MAP,
|
||||
)) {
|
||||
const surpassedThreshold = headers.get(
|
||||
`anthropic-ratelimit-unified-${claimAbbrev}-surpassed-threshold`,
|
||||
)
|
||||
|
||||
// If threshold header is present, user has crossed a warning threshold
|
||||
if (surpassedThreshold !== null) {
|
||||
const utilizationHeader = headers.get(
|
||||
`anthropic-ratelimit-unified-${claimAbbrev}-utilization`,
|
||||
)
|
||||
const resetHeader = headers.get(
|
||||
`anthropic-ratelimit-unified-${claimAbbrev}-reset`,
|
||||
)
|
||||
|
||||
const utilization = utilizationHeader
|
||||
? Number(utilizationHeader)
|
||||
: undefined
|
||||
const resetsAt = resetHeader ? Number(resetHeader) : undefined
|
||||
|
||||
return {
|
||||
status: 'allowed_warning',
|
||||
resetsAt,
|
||||
rateLimitType: rateLimitType as RateLimitType,
|
||||
utilization,
|
||||
unifiedRateLimitFallbackAvailable,
|
||||
isUsingOverage: false,
|
||||
surpassedThreshold: Number(surpassedThreshold),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if time-relative early warning should be triggered for a rate limit type.
|
||||
* Fallback when server doesn't send surpassed-threshold header.
|
||||
* Returns ClaudeAILimits if thresholds are exceeded, null otherwise.
|
||||
*/
|
||||
function getTimeRelativeEarlyWarning(
|
||||
headers: globalThis.Headers,
|
||||
config: EarlyWarningConfig,
|
||||
unifiedRateLimitFallbackAvailable: boolean,
|
||||
): ClaudeAILimits | null {
|
||||
const { rateLimitType, claimAbbrev, windowSeconds, thresholds } = config
|
||||
|
||||
const utilizationHeader = headers.get(
|
||||
`anthropic-ratelimit-unified-${claimAbbrev}-utilization`,
|
||||
)
|
||||
const resetHeader = headers.get(
|
||||
`anthropic-ratelimit-unified-${claimAbbrev}-reset`,
|
||||
)
|
||||
|
||||
if (utilizationHeader === null || resetHeader === null) {
|
||||
return null
|
||||
}
|
||||
|
||||
const utilization = Number(utilizationHeader)
|
||||
const resetsAt = Number(resetHeader)
|
||||
const timeProgress = computeTimeProgress(resetsAt, windowSeconds)
|
||||
|
||||
// Check if any threshold is exceeded: high usage early in the window
|
||||
const shouldWarn = thresholds.some(
|
||||
t => utilization >= t.utilization && timeProgress <= t.timePct,
|
||||
)
|
||||
|
||||
if (!shouldWarn) {
|
||||
return null
|
||||
}
|
||||
|
||||
return {
|
||||
status: 'allowed_warning',
|
||||
resetsAt,
|
||||
rateLimitType,
|
||||
utilization,
|
||||
unifiedRateLimitFallbackAvailable,
|
||||
isUsingOverage: false,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get early warning limits using header-based detection with time-relative fallback.
|
||||
* 1. First checks for surpassed-threshold header (new server-side approach)
|
||||
* 2. Falls back to time-relative thresholds (client-side calculation)
|
||||
*/
|
||||
function getEarlyWarningFromHeaders(
|
||||
headers: globalThis.Headers,
|
||||
unifiedRateLimitFallbackAvailable: boolean,
|
||||
): ClaudeAILimits | null {
|
||||
// Try header-based detection first (preferred when API sends the header)
|
||||
const headerBasedWarning = getHeaderBasedEarlyWarning(
|
||||
headers,
|
||||
unifiedRateLimitFallbackAvailable,
|
||||
)
|
||||
if (headerBasedWarning) {
|
||||
return headerBasedWarning
|
||||
}
|
||||
|
||||
// Fallback: Use time-relative thresholds (client-side calculation)
|
||||
// This catches users burning quota faster than sustainable
|
||||
for (const config of EARLY_WARNING_CONFIGS) {
|
||||
const timeRelativeWarning = getTimeRelativeEarlyWarning(
|
||||
headers,
|
||||
config,
|
||||
unifiedRateLimitFallbackAvailable,
|
||||
)
|
||||
if (timeRelativeWarning) {
|
||||
return timeRelativeWarning
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
function computeNewLimitsFromHeaders(
|
||||
headers: globalThis.Headers,
|
||||
): ClaudeAILimits {
|
||||
const status =
|
||||
(headers.get('anthropic-ratelimit-unified-status') as QuotaStatus) ||
|
||||
'allowed'
|
||||
const resetsAtHeader = headers.get('anthropic-ratelimit-unified-reset')
|
||||
const resetsAt = resetsAtHeader ? Number(resetsAtHeader) : undefined
|
||||
const unifiedRateLimitFallbackAvailable =
|
||||
headers.get('anthropic-ratelimit-unified-fallback') === 'available'
|
||||
|
||||
// Headers for rate limit type and overage support
|
||||
const rateLimitType = headers.get(
|
||||
'anthropic-ratelimit-unified-representative-claim',
|
||||
) as RateLimitType | null
|
||||
const overageStatus = headers.get(
|
||||
'anthropic-ratelimit-unified-overage-status',
|
||||
) as QuotaStatus | null
|
||||
const overageResetsAtHeader = headers.get(
|
||||
'anthropic-ratelimit-unified-overage-reset',
|
||||
)
|
||||
const overageResetsAt = overageResetsAtHeader
|
||||
? Number(overageResetsAtHeader)
|
||||
: undefined
|
||||
|
||||
// Reason why overage is disabled (spending cap or wallet empty)
|
||||
const overageDisabledReason = headers.get(
|
||||
'anthropic-ratelimit-unified-overage-disabled-reason',
|
||||
) as OverageDisabledReason | null
|
||||
|
||||
// Determine if we're using overage (standard limits rejected but overage allowed)
|
||||
const isUsingOverage =
|
||||
status === 'rejected' &&
|
||||
(overageStatus === 'allowed' || overageStatus === 'allowed_warning')
|
||||
|
||||
// Check for early warning based on surpassed-threshold header
|
||||
// If status is allowed/allowed_warning and we find a surpassed threshold, show warning
|
||||
let finalStatus: QuotaStatus = status
|
||||
if (status === 'allowed' || status === 'allowed_warning') {
|
||||
const earlyWarning = getEarlyWarningFromHeaders(
|
||||
headers,
|
||||
unifiedRateLimitFallbackAvailable,
|
||||
)
|
||||
if (earlyWarning) {
|
||||
return earlyWarning
|
||||
}
|
||||
// No early warning threshold surpassed
|
||||
finalStatus = 'allowed'
|
||||
}
|
||||
|
||||
return {
|
||||
status: finalStatus,
|
||||
resetsAt,
|
||||
unifiedRateLimitFallbackAvailable,
|
||||
...(rateLimitType && { rateLimitType }),
|
||||
...(overageStatus && { overageStatus }),
|
||||
...(overageResetsAt && { overageResetsAt }),
|
||||
...(overageDisabledReason && { overageDisabledReason }),
|
||||
isUsingOverage,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cache the extra usage disabled reason from API headers.
|
||||
*/
|
||||
function cacheExtraUsageDisabledReason(headers: globalThis.Headers): void {
|
||||
// A null reason means extra usage is enabled (no disabled reason header)
|
||||
const reason =
|
||||
headers.get('anthropic-ratelimit-unified-overage-disabled-reason') ?? null
|
||||
const cached = getGlobalConfig().cachedExtraUsageDisabledReason
|
||||
if (cached !== reason) {
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
cachedExtraUsageDisabledReason: reason,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
export function extractQuotaStatusFromHeaders(
|
||||
headers: globalThis.Headers,
|
||||
): void {
|
||||
// Check if we need to process rate limits
|
||||
const isSubscriber = isClaudeAISubscriber()
|
||||
|
||||
if (!shouldProcessRateLimits(isSubscriber)) {
|
||||
// If we have any rate limit state, clear it
|
||||
rawUtilization = {}
|
||||
if (currentLimits.status !== 'allowed' || currentLimits.resetsAt) {
|
||||
const defaultLimits: ClaudeAILimits = {
|
||||
status: 'allowed',
|
||||
unifiedRateLimitFallbackAvailable: false,
|
||||
isUsingOverage: false,
|
||||
}
|
||||
emitStatusChange(defaultLimits)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Process headers (applies mocks from /mock-limits command if active)
|
||||
const headersToUse = processRateLimitHeaders(headers)
|
||||
rawUtilization = extractRawUtilization(headersToUse)
|
||||
const newLimits = computeNewLimitsFromHeaders(headersToUse)
|
||||
|
||||
// Cache extra usage status (persists across sessions)
|
||||
cacheExtraUsageDisabledReason(headersToUse)
|
||||
|
||||
if (!isEqual(currentLimits, newLimits)) {
|
||||
emitStatusChange(newLimits)
|
||||
}
|
||||
}
|
||||
|
||||
export function extractQuotaStatusFromError(error: APIError): void {
|
||||
if (
|
||||
!shouldProcessRateLimits(isClaudeAISubscriber()) ||
|
||||
error.status !== 429
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
let newLimits = { ...currentLimits }
|
||||
if (error.headers) {
|
||||
// Process headers (applies mocks from /mock-limits command if active)
|
||||
const headersToUse = processRateLimitHeaders(error.headers)
|
||||
rawUtilization = extractRawUtilization(headersToUse)
|
||||
newLimits = computeNewLimitsFromHeaders(headersToUse)
|
||||
|
||||
// Cache extra usage status (persists across sessions)
|
||||
cacheExtraUsageDisabledReason(headersToUse)
|
||||
}
|
||||
// For errors, always set status to rejected even if headers are not present.
|
||||
newLimits.status = 'rejected'
|
||||
|
||||
if (!isEqual(currentLimits, newLimits)) {
|
||||
emitStatusChange(newLimits)
|
||||
}
|
||||
} catch (e) {
|
||||
logError(e as Error)
|
||||
}
|
||||
}
|
||||
23
src/services/claudeAiLimitsHook.ts
Normal file
23
src/services/claudeAiLimitsHook.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import {
|
||||
type ClaudeAILimits,
|
||||
currentLimits,
|
||||
statusListeners,
|
||||
} from './claudeAiLimits.js'
|
||||
|
||||
export function useClaudeAiLimits(): ClaudeAILimits {
|
||||
const [limits, setLimits] = useState<ClaudeAILimits>({ ...currentLimits })
|
||||
|
||||
useEffect(() => {
|
||||
const listener = (newLimits: ClaudeAILimits) => {
|
||||
setLimits({ ...newLimits })
|
||||
}
|
||||
statusListeners.add(listener)
|
||||
|
||||
return () => {
|
||||
statusListeners.delete(listener)
|
||||
}
|
||||
}, [])
|
||||
|
||||
return limits
|
||||
}
|
||||
153
src/services/compact/apiMicrocompact.ts
Normal file
153
src/services/compact/apiMicrocompact.ts
Normal file
@@ -0,0 +1,153 @@
|
||||
import { FILE_EDIT_TOOL_NAME } from 'src/tools/FileEditTool/constants.js'
|
||||
import { FILE_READ_TOOL_NAME } from 'src/tools/FileReadTool/prompt.js'
|
||||
import { FILE_WRITE_TOOL_NAME } from 'src/tools/FileWriteTool/prompt.js'
|
||||
import { GLOB_TOOL_NAME } from 'src/tools/GlobTool/prompt.js'
|
||||
import { GREP_TOOL_NAME } from 'src/tools/GrepTool/prompt.js'
|
||||
import { NOTEBOOK_EDIT_TOOL_NAME } from 'src/tools/NotebookEditTool/constants.js'
|
||||
import { WEB_FETCH_TOOL_NAME } from 'src/tools/WebFetchTool/prompt.js'
|
||||
import { WEB_SEARCH_TOOL_NAME } from 'src/tools/WebSearchTool/prompt.js'
|
||||
import { SHELL_TOOL_NAMES } from 'src/utils/shell/shellToolUtils.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
|
||||
// docs: https://docs.google.com/document/d/1oCT4evvWTh3P6z-kcfNQwWTCxAhkoFndSaNS9Gm40uw/edit?tab=t.0
|
||||
|
||||
// Default values for context management strategies
|
||||
// Match client-side microcompact token values
|
||||
const DEFAULT_MAX_INPUT_TOKENS = 180_000 // Typical warning threshold
|
||||
const DEFAULT_TARGET_INPUT_TOKENS = 40_000 // Keep last 40k tokens like client-side
|
||||
|
||||
const TOOLS_CLEARABLE_RESULTS = [
|
||||
...SHELL_TOOL_NAMES,
|
||||
GLOB_TOOL_NAME,
|
||||
GREP_TOOL_NAME,
|
||||
FILE_READ_TOOL_NAME,
|
||||
WEB_FETCH_TOOL_NAME,
|
||||
WEB_SEARCH_TOOL_NAME,
|
||||
]
|
||||
|
||||
const TOOLS_CLEARABLE_USES = [
|
||||
FILE_EDIT_TOOL_NAME,
|
||||
FILE_WRITE_TOOL_NAME,
|
||||
NOTEBOOK_EDIT_TOOL_NAME,
|
||||
]
|
||||
|
||||
// Context management strategy types matching API documentation
|
||||
export type ContextEditStrategy =
|
||||
| {
|
||||
type: 'clear_tool_uses_20250919'
|
||||
trigger?: {
|
||||
type: 'input_tokens'
|
||||
value: number
|
||||
}
|
||||
keep?: {
|
||||
type: 'tool_uses'
|
||||
value: number
|
||||
}
|
||||
clear_tool_inputs?: boolean | string[]
|
||||
exclude_tools?: string[]
|
||||
clear_at_least?: {
|
||||
type: 'input_tokens'
|
||||
value: number
|
||||
}
|
||||
}
|
||||
| {
|
||||
type: 'clear_thinking_20251015'
|
||||
keep: { type: 'thinking_turns'; value: number } | 'all'
|
||||
}
|
||||
|
||||
// Context management configuration wrapper
|
||||
export type ContextManagementConfig = {
|
||||
edits: ContextEditStrategy[]
|
||||
}
|
||||
|
||||
// API-based microcompact implementation that uses native context management
|
||||
export function getAPIContextManagement(options?: {
|
||||
hasThinking?: boolean
|
||||
isRedactThinkingActive?: boolean
|
||||
clearAllThinking?: boolean
|
||||
}): ContextManagementConfig | undefined {
|
||||
const {
|
||||
hasThinking = false,
|
||||
isRedactThinkingActive = false,
|
||||
clearAllThinking = false,
|
||||
} = options ?? {}
|
||||
|
||||
const strategies: ContextEditStrategy[] = []
|
||||
|
||||
// Preserve thinking blocks in previous assistant turns. Skip when
|
||||
// redact-thinking is active — redacted blocks have no model-visible content.
|
||||
// When clearAllThinking is set (>1h idle = cache miss), keep only the last
|
||||
// thinking turn — the API schema requires value >= 1, and omitting the edit
|
||||
// falls back to the model-policy default (often "all"), which wouldn't clear.
|
||||
if (hasThinking && !isRedactThinkingActive) {
|
||||
strategies.push({
|
||||
type: 'clear_thinking_20251015',
|
||||
keep: clearAllThinking ? { type: 'thinking_turns', value: 1 } : 'all',
|
||||
})
|
||||
}
|
||||
|
||||
// Tool clearing strategies are ant-only
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return strategies.length > 0 ? { edits: strategies } : undefined
|
||||
}
|
||||
|
||||
const useClearToolResults = isEnvTruthy(
|
||||
process.env.USE_API_CLEAR_TOOL_RESULTS,
|
||||
)
|
||||
const useClearToolUses = isEnvTruthy(process.env.USE_API_CLEAR_TOOL_USES)
|
||||
|
||||
// If no tool clearing strategy is enabled, return early
|
||||
if (!useClearToolResults && !useClearToolUses) {
|
||||
return strategies.length > 0 ? { edits: strategies } : undefined
|
||||
}
|
||||
|
||||
if (useClearToolResults) {
|
||||
const triggerThreshold = process.env.API_MAX_INPUT_TOKENS
|
||||
? parseInt(process.env.API_MAX_INPUT_TOKENS)
|
||||
: DEFAULT_MAX_INPUT_TOKENS
|
||||
const keepTarget = process.env.API_TARGET_INPUT_TOKENS
|
||||
? parseInt(process.env.API_TARGET_INPUT_TOKENS)
|
||||
: DEFAULT_TARGET_INPUT_TOKENS
|
||||
|
||||
const strategy: ContextEditStrategy = {
|
||||
type: 'clear_tool_uses_20250919',
|
||||
trigger: {
|
||||
type: 'input_tokens',
|
||||
value: triggerThreshold,
|
||||
},
|
||||
clear_at_least: {
|
||||
type: 'input_tokens',
|
||||
value: triggerThreshold - keepTarget,
|
||||
},
|
||||
clear_tool_inputs: TOOLS_CLEARABLE_RESULTS,
|
||||
}
|
||||
|
||||
strategies.push(strategy)
|
||||
}
|
||||
|
||||
if (useClearToolUses) {
|
||||
const triggerThreshold = process.env.API_MAX_INPUT_TOKENS
|
||||
? parseInt(process.env.API_MAX_INPUT_TOKENS)
|
||||
: DEFAULT_MAX_INPUT_TOKENS
|
||||
const keepTarget = process.env.API_TARGET_INPUT_TOKENS
|
||||
? parseInt(process.env.API_TARGET_INPUT_TOKENS)
|
||||
: DEFAULT_TARGET_INPUT_TOKENS
|
||||
|
||||
const strategy: ContextEditStrategy = {
|
||||
type: 'clear_tool_uses_20250919',
|
||||
trigger: {
|
||||
type: 'input_tokens',
|
||||
value: triggerThreshold,
|
||||
},
|
||||
clear_at_least: {
|
||||
type: 'input_tokens',
|
||||
value: triggerThreshold - keepTarget,
|
||||
},
|
||||
exclude_tools: TOOLS_CLEARABLE_USES,
|
||||
}
|
||||
|
||||
strategies.push(strategy)
|
||||
}
|
||||
|
||||
return strategies.length > 0 ? { edits: strategies } : undefined
|
||||
}
|
||||
351
src/services/compact/autoCompact.ts
Normal file
351
src/services/compact/autoCompact.ts
Normal file
@@ -0,0 +1,351 @@
|
||||
import { feature } from 'bun:bundle'
|
||||
import { markPostCompaction } from 'src/bootstrap/state.js'
|
||||
import { getSdkBetas } from '../../bootstrap/state.js'
|
||||
import type { QuerySource } from '../../constants/querySource.js'
|
||||
import type { ToolUseContext } from '../../Tool.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { getGlobalConfig } from '../../utils/config.js'
|
||||
import { getContextWindowForModel } from '../../utils/context.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { hasExactErrorMessage } from '../../utils/errors.js'
|
||||
import type { CacheSafeParams } from '../../utils/forkedAgent.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { tokenCountWithEstimation } from '../../utils/tokens.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
import { getMaxOutputTokensForModel } from '../api/claude.js'
|
||||
import { notifyCompaction } from '../api/promptCacheBreakDetection.js'
|
||||
import { setLastSummarizedMessageId } from '../SessionMemory/sessionMemoryUtils.js'
|
||||
import {
|
||||
type CompactionResult,
|
||||
compactConversation,
|
||||
ERROR_MESSAGE_USER_ABORT,
|
||||
type RecompactionInfo,
|
||||
} from './compact.js'
|
||||
import { runPostCompactCleanup } from './postCompactCleanup.js'
|
||||
import { trySessionMemoryCompaction } from './sessionMemoryCompact.js'
|
||||
|
||||
// Reserve this many tokens for output during compaction
|
||||
// Based on p99.99 of compact summary output being 17,387 tokens.
|
||||
const MAX_OUTPUT_TOKENS_FOR_SUMMARY = 20_000
|
||||
|
||||
// Returns the context window size minus the max output tokens for the model
|
||||
export function getEffectiveContextWindowSize(model: string): number {
|
||||
const reservedTokensForSummary = Math.min(
|
||||
getMaxOutputTokensForModel(model),
|
||||
MAX_OUTPUT_TOKENS_FOR_SUMMARY,
|
||||
)
|
||||
let contextWindow = getContextWindowForModel(model, getSdkBetas())
|
||||
|
||||
const autoCompactWindow = process.env.CLAUDE_CODE_AUTO_COMPACT_WINDOW
|
||||
if (autoCompactWindow) {
|
||||
const parsed = parseInt(autoCompactWindow, 10)
|
||||
if (!isNaN(parsed) && parsed > 0) {
|
||||
contextWindow = Math.min(contextWindow, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
return contextWindow - reservedTokensForSummary
|
||||
}
|
||||
|
||||
export type AutoCompactTrackingState = {
|
||||
compacted: boolean
|
||||
turnCounter: number
|
||||
// Unique ID per turn
|
||||
turnId: string
|
||||
// Consecutive autocompact failures. Reset on success.
|
||||
// Used as a circuit breaker to stop retrying when the context is
|
||||
// irrecoverably over the limit (e.g., prompt_too_long).
|
||||
consecutiveFailures?: number
|
||||
}
|
||||
|
||||
export const AUTOCOMPACT_BUFFER_TOKENS = 13_000
|
||||
export const WARNING_THRESHOLD_BUFFER_TOKENS = 20_000
|
||||
export const ERROR_THRESHOLD_BUFFER_TOKENS = 20_000
|
||||
export const MANUAL_COMPACT_BUFFER_TOKENS = 3_000
|
||||
|
||||
// Stop trying autocompact after this many consecutive failures.
|
||||
// BQ 2026-03-10: 1,279 sessions had 50+ consecutive failures (up to 3,272)
|
||||
// in a single session, wasting ~250K API calls/day globally.
|
||||
const MAX_CONSECUTIVE_AUTOCOMPACT_FAILURES = 3
|
||||
|
||||
export function getAutoCompactThreshold(model: string): number {
|
||||
const effectiveContextWindow = getEffectiveContextWindowSize(model)
|
||||
|
||||
const autocompactThreshold =
|
||||
effectiveContextWindow - AUTOCOMPACT_BUFFER_TOKENS
|
||||
|
||||
// Override for easier testing of autocompact
|
||||
const envPercent = process.env.CLAUDE_AUTOCOMPACT_PCT_OVERRIDE
|
||||
if (envPercent) {
|
||||
const parsed = parseFloat(envPercent)
|
||||
if (!isNaN(parsed) && parsed > 0 && parsed <= 100) {
|
||||
const percentageThreshold = Math.floor(
|
||||
effectiveContextWindow * (parsed / 100),
|
||||
)
|
||||
return Math.min(percentageThreshold, autocompactThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
return autocompactThreshold
|
||||
}
|
||||
|
||||
export function calculateTokenWarningState(
|
||||
tokenUsage: number,
|
||||
model: string,
|
||||
): {
|
||||
percentLeft: number
|
||||
isAboveWarningThreshold: boolean
|
||||
isAboveErrorThreshold: boolean
|
||||
isAboveAutoCompactThreshold: boolean
|
||||
isAtBlockingLimit: boolean
|
||||
} {
|
||||
const autoCompactThreshold = getAutoCompactThreshold(model)
|
||||
const threshold = isAutoCompactEnabled()
|
||||
? autoCompactThreshold
|
||||
: getEffectiveContextWindowSize(model)
|
||||
|
||||
const percentLeft = Math.max(
|
||||
0,
|
||||
Math.round(((threshold - tokenUsage) / threshold) * 100),
|
||||
)
|
||||
|
||||
const warningThreshold = threshold - WARNING_THRESHOLD_BUFFER_TOKENS
|
||||
const errorThreshold = threshold - ERROR_THRESHOLD_BUFFER_TOKENS
|
||||
|
||||
const isAboveWarningThreshold = tokenUsage >= warningThreshold
|
||||
const isAboveErrorThreshold = tokenUsage >= errorThreshold
|
||||
|
||||
const isAboveAutoCompactThreshold =
|
||||
isAutoCompactEnabled() && tokenUsage >= autoCompactThreshold
|
||||
|
||||
const actualContextWindow = getEffectiveContextWindowSize(model)
|
||||
const defaultBlockingLimit =
|
||||
actualContextWindow - MANUAL_COMPACT_BUFFER_TOKENS
|
||||
|
||||
// Allow override for testing
|
||||
const blockingLimitOverride = process.env.CLAUDE_CODE_BLOCKING_LIMIT_OVERRIDE
|
||||
const parsedOverride = blockingLimitOverride
|
||||
? parseInt(blockingLimitOverride, 10)
|
||||
: NaN
|
||||
const blockingLimit =
|
||||
!isNaN(parsedOverride) && parsedOverride > 0
|
||||
? parsedOverride
|
||||
: defaultBlockingLimit
|
||||
|
||||
const isAtBlockingLimit = tokenUsage >= blockingLimit
|
||||
|
||||
return {
|
||||
percentLeft,
|
||||
isAboveWarningThreshold,
|
||||
isAboveErrorThreshold,
|
||||
isAboveAutoCompactThreshold,
|
||||
isAtBlockingLimit,
|
||||
}
|
||||
}
|
||||
|
||||
export function isAutoCompactEnabled(): boolean {
|
||||
if (isEnvTruthy(process.env.DISABLE_COMPACT)) {
|
||||
return false
|
||||
}
|
||||
// Allow disabling just auto-compact (keeps manual /compact working)
|
||||
if (isEnvTruthy(process.env.DISABLE_AUTO_COMPACT)) {
|
||||
return false
|
||||
}
|
||||
// Check if user has disabled auto-compact in their settings
|
||||
const userConfig = getGlobalConfig()
|
||||
return userConfig.autoCompactEnabled
|
||||
}
|
||||
|
||||
export async function shouldAutoCompact(
|
||||
messages: Message[],
|
||||
model: string,
|
||||
querySource?: QuerySource,
|
||||
// Snip removes messages but the surviving assistant's usage still reflects
|
||||
// pre-snip context, so tokenCountWithEstimation can't see the savings.
|
||||
// Subtract the rough-delta that snip already computed.
|
||||
snipTokensFreed = 0,
|
||||
): Promise<boolean> {
|
||||
// Recursion guards. session_memory and compact are forked agents that
|
||||
// would deadlock.
|
||||
if (querySource === 'session_memory' || querySource === 'compact') {
|
||||
return false
|
||||
}
|
||||
// marble_origami is the ctx-agent — if ITS context blows up and
|
||||
// autocompact fires, runPostCompactCleanup calls resetContextCollapse()
|
||||
// which destroys the MAIN thread's committed log (module-level state
|
||||
// shared across forks). Inside feature() so the string DCEs from
|
||||
// external builds (it's in excluded-strings.txt).
|
||||
if (feature('CONTEXT_COLLAPSE')) {
|
||||
if (querySource === 'marble_origami') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (!isAutoCompactEnabled()) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Reactive-only mode: suppress proactive autocompact, let reactive compact
|
||||
// catch the API's prompt-too-long. feature() wrapper keeps the flag string
|
||||
// out of external builds (REACTIVE_COMPACT is ant-only).
|
||||
// Note: returning false here also means autoCompactIfNeeded never reaches
|
||||
// trySessionMemoryCompaction in the query loop — the /compact call site
|
||||
// still tries session memory first. Revisit if reactive-only graduates.
|
||||
if (feature('REACTIVE_COMPACT')) {
|
||||
if (getFeatureValue_CACHED_MAY_BE_STALE('tengu_cobalt_raccoon', false)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Context-collapse mode: same suppression. Collapse IS the context
|
||||
// management system when it's on — the 90% commit / 95% blocking-spawn
|
||||
// flow owns the headroom problem. Autocompact firing at effective-13k
|
||||
// (~93% of effective) sits right between collapse's commit-start (90%)
|
||||
// and blocking (95%), so it would race collapse and usually win, nuking
|
||||
// granular context that collapse was about to save. Gating here rather
|
||||
// than in isAutoCompactEnabled() keeps reactiveCompact alive as the 413
|
||||
// fallback (it consults isAutoCompactEnabled directly) and leaves
|
||||
// sessionMemory + manual /compact working.
|
||||
//
|
||||
// Consult isContextCollapseEnabled (not the raw gate) so the
|
||||
// CLAUDE_CONTEXT_COLLAPSE env override is honored here too. require()
|
||||
// inside the block breaks the init-time cycle (this file exports
|
||||
// getEffectiveContextWindowSize which collapse's index imports).
|
||||
if (feature('CONTEXT_COLLAPSE')) {
|
||||
/* eslint-disable @typescript-eslint/no-require-imports */
|
||||
const { isContextCollapseEnabled } =
|
||||
require('../contextCollapse/index.js') as typeof import('../contextCollapse/index.js')
|
||||
/* eslint-enable @typescript-eslint/no-require-imports */
|
||||
if (isContextCollapseEnabled()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const tokenCount = tokenCountWithEstimation(messages) - snipTokensFreed
|
||||
const threshold = getAutoCompactThreshold(model)
|
||||
const effectiveWindow = getEffectiveContextWindowSize(model)
|
||||
|
||||
logForDebugging(
|
||||
`autocompact: tokens=${tokenCount} threshold=${threshold} effectiveWindow=${effectiveWindow}${snipTokensFreed > 0 ? ` snipFreed=${snipTokensFreed}` : ''}`,
|
||||
)
|
||||
|
||||
const { isAboveAutoCompactThreshold } = calculateTokenWarningState(
|
||||
tokenCount,
|
||||
model,
|
||||
)
|
||||
|
||||
return isAboveAutoCompactThreshold
|
||||
}
|
||||
|
||||
export async function autoCompactIfNeeded(
|
||||
messages: Message[],
|
||||
toolUseContext: ToolUseContext,
|
||||
cacheSafeParams: CacheSafeParams,
|
||||
querySource?: QuerySource,
|
||||
tracking?: AutoCompactTrackingState,
|
||||
snipTokensFreed?: number,
|
||||
): Promise<{
|
||||
wasCompacted: boolean
|
||||
compactionResult?: CompactionResult
|
||||
consecutiveFailures?: number
|
||||
}> {
|
||||
if (isEnvTruthy(process.env.DISABLE_COMPACT)) {
|
||||
return { wasCompacted: false }
|
||||
}
|
||||
|
||||
// Circuit breaker: stop retrying after N consecutive failures.
|
||||
// Without this, sessions where context is irrecoverably over the limit
|
||||
// hammer the API with doomed compaction attempts on every turn.
|
||||
if (
|
||||
tracking?.consecutiveFailures !== undefined &&
|
||||
tracking.consecutiveFailures >= MAX_CONSECUTIVE_AUTOCOMPACT_FAILURES
|
||||
) {
|
||||
return { wasCompacted: false }
|
||||
}
|
||||
|
||||
const model = toolUseContext.options.mainLoopModel
|
||||
const shouldCompact = await shouldAutoCompact(
|
||||
messages,
|
||||
model,
|
||||
querySource,
|
||||
snipTokensFreed,
|
||||
)
|
||||
|
||||
if (!shouldCompact) {
|
||||
return { wasCompacted: false }
|
||||
}
|
||||
|
||||
const recompactionInfo: RecompactionInfo = {
|
||||
isRecompactionInChain: tracking?.compacted === true,
|
||||
turnsSincePreviousCompact: tracking?.turnCounter ?? -1,
|
||||
previousCompactTurnId: tracking?.turnId,
|
||||
autoCompactThreshold: getAutoCompactThreshold(model),
|
||||
querySource,
|
||||
}
|
||||
|
||||
// EXPERIMENT: Try session memory compaction first
|
||||
const sessionMemoryResult = await trySessionMemoryCompaction(
|
||||
messages,
|
||||
toolUseContext.agentId,
|
||||
recompactionInfo.autoCompactThreshold,
|
||||
)
|
||||
if (sessionMemoryResult) {
|
||||
// Reset lastSummarizedMessageId since session memory compaction prunes messages
|
||||
// and the old message UUID will no longer exist after the REPL replaces messages
|
||||
setLastSummarizedMessageId(undefined)
|
||||
runPostCompactCleanup(querySource)
|
||||
// Reset cache read baseline so the post-compact drop isn't flagged as a
|
||||
// break. compactConversation does this internally; SM-compact doesn't.
|
||||
// BQ 2026-03-01: missing this made 20% of tengu_prompt_cache_break events
|
||||
// false positives (systemPromptChanged=true, timeSinceLastAssistantMsg=-1).
|
||||
if (feature('PROMPT_CACHE_BREAK_DETECTION')) {
|
||||
notifyCompaction(querySource ?? 'compact', toolUseContext.agentId)
|
||||
}
|
||||
markPostCompaction()
|
||||
return {
|
||||
wasCompacted: true,
|
||||
compactionResult: sessionMemoryResult,
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const compactionResult = await compactConversation(
|
||||
messages,
|
||||
toolUseContext,
|
||||
cacheSafeParams,
|
||||
true, // Suppress user questions for autocompact
|
||||
undefined, // No custom instructions for autocompact
|
||||
true, // isAutoCompact
|
||||
recompactionInfo,
|
||||
)
|
||||
|
||||
// Reset lastSummarizedMessageId since legacy compaction replaces all messages
|
||||
// and the old message UUID will no longer exist in the new messages array
|
||||
setLastSummarizedMessageId(undefined)
|
||||
runPostCompactCleanup(querySource)
|
||||
|
||||
return {
|
||||
wasCompacted: true,
|
||||
compactionResult,
|
||||
// Reset failure count on success
|
||||
consecutiveFailures: 0,
|
||||
}
|
||||
} catch (error) {
|
||||
if (!hasExactErrorMessage(error, ERROR_MESSAGE_USER_ABORT)) {
|
||||
logError(error)
|
||||
}
|
||||
// Increment consecutive failure count for circuit breaker.
|
||||
// The caller threads this through autoCompactTracking so the
|
||||
// next query loop iteration can skip futile retry attempts.
|
||||
const prevFailures = tracking?.consecutiveFailures ?? 0
|
||||
const nextFailures = prevFailures + 1
|
||||
if (nextFailures >= MAX_CONSECUTIVE_AUTOCOMPACT_FAILURES) {
|
||||
logForDebugging(
|
||||
`autocompact: circuit breaker tripped after ${nextFailures} consecutive failures — skipping future attempts this session`,
|
||||
{ level: 'warn' },
|
||||
)
|
||||
}
|
||||
return { wasCompacted: false, consecutiveFailures: nextFailures }
|
||||
}
|
||||
}
|
||||
1705
src/services/compact/compact.ts
Normal file
1705
src/services/compact/compact.ts
Normal file
File diff suppressed because it is too large
Load Diff
16
src/services/compact/compactWarningHook.ts
Normal file
16
src/services/compact/compactWarningHook.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { useSyncExternalStore } from 'react'
|
||||
import { compactWarningStore } from './compactWarningState.js'
|
||||
|
||||
/**
|
||||
* React hook to subscribe to compact warning suppression state.
|
||||
*
|
||||
* Lives in its own file so that compactWarningState.ts stays React-free:
|
||||
* microCompact.ts imports the pure state functions, and pulling React into
|
||||
* that module graph would drag it into the print-mode startup path.
|
||||
*/
|
||||
export function useCompactWarningSuppression(): boolean {
|
||||
return useSyncExternalStore(
|
||||
compactWarningStore.subscribe,
|
||||
compactWarningStore.getState,
|
||||
)
|
||||
}
|
||||
18
src/services/compact/compactWarningState.ts
Normal file
18
src/services/compact/compactWarningState.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
import { createStore } from '../../state/store.js'
|
||||
|
||||
/**
|
||||
* Tracks whether the "context left until autocompact" warning should be suppressed.
|
||||
* We suppress immediately after successful compaction since we don't have accurate
|
||||
* token counts until the next API response.
|
||||
*/
|
||||
export const compactWarningStore = createStore<boolean>(false)
|
||||
|
||||
/** Suppress the compact warning. Call after successful compaction. */
|
||||
export function suppressCompactWarning(): void {
|
||||
compactWarningStore.setState(() => true)
|
||||
}
|
||||
|
||||
/** Clear the compact warning suppression. Called at start of new compact attempt. */
|
||||
export function clearCompactWarningSuppression(): void {
|
||||
compactWarningStore.setState(() => false)
|
||||
}
|
||||
63
src/services/compact/grouping.ts
Normal file
63
src/services/compact/grouping.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
import type { Message } from '../../types/message.js'
|
||||
|
||||
/**
|
||||
* Groups messages at API-round boundaries: one group per API round-trip.
|
||||
* A boundary fires when a NEW assistant response begins (different
|
||||
* message.id from the prior assistant). For well-formed conversations
|
||||
* this is an API-safe split point — the API contract requires every
|
||||
* tool_use to be resolved before the next assistant turn, so pairing
|
||||
* validity falls out of the assistant-id boundary. For malformed inputs
|
||||
* (dangling tool_use after resume/truncation) the fork's
|
||||
* ensureToolResultPairing repairs the split at API time.
|
||||
*
|
||||
* Replaces the prior human-turn grouping (boundaries only at real user
|
||||
* prompts) with finer-grained API-round grouping, allowing reactive
|
||||
* compact to operate on single-prompt agentic sessions (SDK/CCR/eval
|
||||
* callers) where the entire workload is one human turn.
|
||||
*
|
||||
* Extracted to its own file to break the compact.ts ↔ compactMessages.ts
|
||||
* cycle (CC-1180) — the cycle shifted module-init order enough to surface
|
||||
* a latent ws CJS/ESM resolution race in CI shard-2.
|
||||
*/
|
||||
export function groupMessagesByApiRound(messages: Message[]): Message[][] {
|
||||
const groups: Message[][] = []
|
||||
let current: Message[] = []
|
||||
// message.id of the most recently seen assistant. This is the sole
|
||||
// boundary gate: streaming chunks from the same API response share an
|
||||
// id, so boundaries only fire at the start of a genuinely new round.
|
||||
// normalizeMessages yields one AssistantMessage per content block, and
|
||||
// StreamingToolExecutor interleaves tool_results between chunks live
|
||||
// (yield order, not concat order — see query.ts:613). The id check
|
||||
// correctly keeps `[tu_A(id=X), result_A, tu_B(id=X)]` in one group.
|
||||
let lastAssistantId: string | undefined
|
||||
|
||||
// In a well-formed conversation the API contract guarantees every
|
||||
// tool_use is resolved before the next assistant turn, so lastAssistantId
|
||||
// alone is a sufficient boundary gate. Tracking unresolved tool_use IDs
|
||||
// would only do work when the conversation is malformed (dangling tool_use
|
||||
// after resume-from-partial-batch or max_tokens truncation) — and in that
|
||||
// case it pins the gate shut forever, merging all subsequent rounds into
|
||||
// one group. We let those boundaries fire; the summarizer fork's own
|
||||
// ensureToolResultPairing at claude.ts:1136 repairs the dangling tu at
|
||||
// API time.
|
||||
for (const msg of messages) {
|
||||
if (
|
||||
msg.type === 'assistant' &&
|
||||
msg.message.id !== lastAssistantId &&
|
||||
current.length > 0
|
||||
) {
|
||||
groups.push(current)
|
||||
current = [msg]
|
||||
} else {
|
||||
current.push(msg)
|
||||
}
|
||||
if (msg.type === 'assistant') {
|
||||
lastAssistantId = msg.message.id
|
||||
}
|
||||
}
|
||||
|
||||
if (current.length > 0) {
|
||||
groups.push(current)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
530
src/services/compact/microCompact.ts
Normal file
530
src/services/compact/microCompact.ts
Normal file
@@ -0,0 +1,530 @@
|
||||
import { feature } from 'bun:bundle'
|
||||
import type { ToolResultBlockParam } from '@anthropic-ai/sdk/resources/index.mjs'
|
||||
import type { QuerySource } from '../../constants/querySource.js'
|
||||
import type { ToolUseContext } from '../../Tool.js'
|
||||
import { FILE_EDIT_TOOL_NAME } from '../../tools/FileEditTool/constants.js'
|
||||
import { FILE_READ_TOOL_NAME } from '../../tools/FileReadTool/prompt.js'
|
||||
import { FILE_WRITE_TOOL_NAME } from '../../tools/FileWriteTool/prompt.js'
|
||||
import { GLOB_TOOL_NAME } from '../../tools/GlobTool/prompt.js'
|
||||
import { GREP_TOOL_NAME } from '../../tools/GrepTool/prompt.js'
|
||||
import { WEB_FETCH_TOOL_NAME } from '../../tools/WebFetchTool/prompt.js'
|
||||
import { WEB_SEARCH_TOOL_NAME } from '../../tools/WebSearchTool/prompt.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { getMainLoopModel } from '../../utils/model/model.js'
|
||||
import { SHELL_TOOL_NAMES } from '../../utils/shell/shellToolUtils.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
import { notifyCacheDeletion } from '../api/promptCacheBreakDetection.js'
|
||||
import { roughTokenCountEstimation } from '../tokenEstimation.js'
|
||||
import {
|
||||
clearCompactWarningSuppression,
|
||||
suppressCompactWarning,
|
||||
} from './compactWarningState.js'
|
||||
import {
|
||||
getTimeBasedMCConfig,
|
||||
type TimeBasedMCConfig,
|
||||
} from './timeBasedMCConfig.js'
|
||||
|
||||
// Inline from utils/toolResultStorage.ts — importing that file pulls in
|
||||
// sessionStorage → utils/messages → services/api/errors, completing a
|
||||
// circular-deps loop back through this file via promptCacheBreakDetection.
|
||||
// Drift is caught by a test asserting equality with the source-of-truth.
|
||||
export const TIME_BASED_MC_CLEARED_MESSAGE = '[Old tool result content cleared]'
|
||||
|
||||
const IMAGE_MAX_TOKEN_SIZE = 2000
|
||||
|
||||
// Only compact these tools
|
||||
const COMPACTABLE_TOOLS = new Set<string>([
|
||||
FILE_READ_TOOL_NAME,
|
||||
...SHELL_TOOL_NAMES,
|
||||
GREP_TOOL_NAME,
|
||||
GLOB_TOOL_NAME,
|
||||
WEB_SEARCH_TOOL_NAME,
|
||||
WEB_FETCH_TOOL_NAME,
|
||||
FILE_EDIT_TOOL_NAME,
|
||||
FILE_WRITE_TOOL_NAME,
|
||||
])
|
||||
|
||||
// --- Cached microcompact state (ant-only, gated by feature('CACHED_MICROCOMPACT')) ---
|
||||
|
||||
// Lazy-initialized cached MC module and state to avoid importing in external builds.
|
||||
// The imports and state live inside feature() checks for dead code elimination.
|
||||
let cachedMCModule: typeof import('./cachedMicrocompact.js') | null = null
|
||||
let cachedMCState: import('./cachedMicrocompact.js').CachedMCState | null = null
|
||||
let pendingCacheEdits:
|
||||
| import('./cachedMicrocompact.js').CacheEditsBlock
|
||||
| null = null
|
||||
|
||||
async function getCachedMCModule(): Promise<
|
||||
typeof import('./cachedMicrocompact.js')
|
||||
> {
|
||||
if (!cachedMCModule) {
|
||||
cachedMCModule = await import('./cachedMicrocompact.js')
|
||||
}
|
||||
return cachedMCModule
|
||||
}
|
||||
|
||||
function ensureCachedMCState(): import('./cachedMicrocompact.js').CachedMCState {
|
||||
if (!cachedMCState && cachedMCModule) {
|
||||
cachedMCState = cachedMCModule.createCachedMCState()
|
||||
}
|
||||
if (!cachedMCState) {
|
||||
throw new Error(
|
||||
'cachedMCState not initialized — getCachedMCModule() must be called first',
|
||||
)
|
||||
}
|
||||
return cachedMCState
|
||||
}
|
||||
|
||||
/**
|
||||
* Get new pending cache edits to be included in the next API request.
|
||||
* Returns null if there are no new pending edits.
|
||||
* Clears the pending state (caller must pin them after insertion).
|
||||
*/
|
||||
export function consumePendingCacheEdits():
|
||||
| import('./cachedMicrocompact.js').CacheEditsBlock
|
||||
| null {
|
||||
const edits = pendingCacheEdits
|
||||
pendingCacheEdits = null
|
||||
return edits
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all previously-pinned cache edits that must be re-sent at their
|
||||
* original positions for cache hits.
|
||||
*/
|
||||
export function getPinnedCacheEdits(): import('./cachedMicrocompact.js').PinnedCacheEdits[] {
|
||||
if (!cachedMCState) {
|
||||
return []
|
||||
}
|
||||
return cachedMCState.pinnedEdits
|
||||
}
|
||||
|
||||
/**
|
||||
* Pin a new cache_edits block to a specific user message position.
|
||||
* Called after inserting new edits so they are re-sent in subsequent calls.
|
||||
*/
|
||||
export function pinCacheEdits(
|
||||
userMessageIndex: number,
|
||||
block: import('./cachedMicrocompact.js').CacheEditsBlock,
|
||||
): void {
|
||||
if (cachedMCState) {
|
||||
cachedMCState.pinnedEdits.push({ userMessageIndex, block })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks all registered tools as sent to the API.
|
||||
* Called after a successful API response.
|
||||
*/
|
||||
export function markToolsSentToAPIState(): void {
|
||||
if (cachedMCState && cachedMCModule) {
|
||||
cachedMCModule.markToolsSentToAPI(cachedMCState)
|
||||
}
|
||||
}
|
||||
|
||||
export function resetMicrocompactState(): void {
|
||||
if (cachedMCState && cachedMCModule) {
|
||||
cachedMCModule.resetCachedMCState(cachedMCState)
|
||||
}
|
||||
pendingCacheEdits = null
|
||||
}
|
||||
|
||||
// Helper to calculate tool result tokens
|
||||
function calculateToolResultTokens(block: ToolResultBlockParam): number {
|
||||
if (!block.content) {
|
||||
return 0
|
||||
}
|
||||
|
||||
if (typeof block.content === 'string') {
|
||||
return roughTokenCountEstimation(block.content)
|
||||
}
|
||||
|
||||
// Array of TextBlockParam | ImageBlockParam | DocumentBlockParam
|
||||
return block.content.reduce((sum, item) => {
|
||||
if (item.type === 'text') {
|
||||
return sum + roughTokenCountEstimation(item.text)
|
||||
} else if (item.type === 'image' || item.type === 'document') {
|
||||
// Images/documents are approximately 2000 tokens regardless of format
|
||||
return sum + IMAGE_MAX_TOKEN_SIZE
|
||||
}
|
||||
return sum
|
||||
}, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate token count for messages by extracting text content
|
||||
* Used for rough token estimation when we don't have accurate API counts
|
||||
* Pads estimate by 4/3 to be conservative since we're approximating
|
||||
*/
|
||||
export function estimateMessageTokens(messages: Message[]): number {
|
||||
let totalTokens = 0
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.type !== 'user' && message.type !== 'assistant') {
|
||||
continue
|
||||
}
|
||||
|
||||
if (!Array.isArray(message.message.content)) {
|
||||
continue
|
||||
}
|
||||
|
||||
for (const block of message.message.content) {
|
||||
if (block.type === 'text') {
|
||||
totalTokens += roughTokenCountEstimation(block.text)
|
||||
} else if (block.type === 'tool_result') {
|
||||
totalTokens += calculateToolResultTokens(block)
|
||||
} else if (block.type === 'image' || block.type === 'document') {
|
||||
totalTokens += IMAGE_MAX_TOKEN_SIZE
|
||||
} else if (block.type === 'thinking') {
|
||||
// Match roughTokenCountEstimationForBlock: count only the thinking
|
||||
// text, not the JSON wrapper or signature (signature is metadata,
|
||||
// not model-tokenized content).
|
||||
totalTokens += roughTokenCountEstimation(block.thinking)
|
||||
} else if (block.type === 'redacted_thinking') {
|
||||
totalTokens += roughTokenCountEstimation(block.data)
|
||||
} else if (block.type === 'tool_use') {
|
||||
// Match roughTokenCountEstimationForBlock: count name + input,
|
||||
// not the JSON wrapper or id field.
|
||||
totalTokens += roughTokenCountEstimation(
|
||||
block.name + jsonStringify(block.input ?? {}),
|
||||
)
|
||||
} else {
|
||||
// server_tool_use, web_search_tool_result, etc.
|
||||
totalTokens += roughTokenCountEstimation(jsonStringify(block))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pad estimate by 4/3 to be conservative since we're approximating
|
||||
return Math.ceil(totalTokens * (4 / 3))
|
||||
}
|
||||
|
||||
export type PendingCacheEdits = {
|
||||
trigger: 'auto'
|
||||
deletedToolIds: string[]
|
||||
// Baseline cumulative cache_deleted_input_tokens from the previous API response,
|
||||
// used to compute the per-operation delta (the API value is sticky/cumulative)
|
||||
baselineCacheDeletedTokens: number
|
||||
}
|
||||
|
||||
export type MicrocompactResult = {
|
||||
messages: Message[]
|
||||
compactionInfo?: {
|
||||
pendingCacheEdits?: PendingCacheEdits
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Walk messages and collect tool_use IDs whose tool name is in
|
||||
* COMPACTABLE_TOOLS, in encounter order. Shared by both microcompact paths.
|
||||
*/
|
||||
function collectCompactableToolIds(messages: Message[]): string[] {
|
||||
const ids: string[] = []
|
||||
for (const message of messages) {
|
||||
if (
|
||||
message.type === 'assistant' &&
|
||||
Array.isArray(message.message.content)
|
||||
) {
|
||||
for (const block of message.message.content) {
|
||||
if (block.type === 'tool_use' && COMPACTABLE_TOOLS.has(block.name)) {
|
||||
ids.push(block.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Prefix-match because promptCategory.ts sets the querySource to
|
||||
// 'repl_main_thread:outputStyle:<style>' when a non-default output style
|
||||
// is active. The bare 'repl_main_thread' is only used for the default style.
|
||||
// query.ts:350/1451 use the same startsWith pattern; the pre-existing
|
||||
// cached-MC `=== 'repl_main_thread'` check was a latent bug — users with a
|
||||
// non-default output style were silently excluded from cached MC.
|
||||
function isMainThreadSource(querySource: QuerySource | undefined): boolean {
|
||||
return !querySource || querySource.startsWith('repl_main_thread')
|
||||
}
|
||||
|
||||
export async function microcompactMessages(
|
||||
messages: Message[],
|
||||
toolUseContext?: ToolUseContext,
|
||||
querySource?: QuerySource,
|
||||
): Promise<MicrocompactResult> {
|
||||
// Clear suppression flag at start of new microcompact attempt
|
||||
clearCompactWarningSuppression()
|
||||
|
||||
// Time-based trigger runs first and short-circuits. If the gap since the
|
||||
// last assistant message exceeds the threshold, the server cache has expired
|
||||
// and the full prefix will be rewritten regardless — so content-clear old
|
||||
// tool results now, before the request, to shrink what gets rewritten.
|
||||
// Cached MC (cache-editing) is skipped when this fires: editing assumes a
|
||||
// warm cache, and we just established it's cold.
|
||||
const timeBasedResult = maybeTimeBasedMicrocompact(messages, querySource)
|
||||
if (timeBasedResult) {
|
||||
return timeBasedResult
|
||||
}
|
||||
|
||||
// Only run cached MC for the main thread to prevent forked agents
|
||||
// (session_memory, prompt_suggestion, etc.) from registering their
|
||||
// tool_results in the global cachedMCState, which would cause the main
|
||||
// thread to try deleting tools that don't exist in its own conversation.
|
||||
if (feature('CACHED_MICROCOMPACT')) {
|
||||
const mod = await getCachedMCModule()
|
||||
const model = toolUseContext?.options.mainLoopModel ?? getMainLoopModel()
|
||||
if (
|
||||
mod.isCachedMicrocompactEnabled() &&
|
||||
mod.isModelSupportedForCacheEditing(model) &&
|
||||
isMainThreadSource(querySource)
|
||||
) {
|
||||
return await cachedMicrocompactPath(messages, querySource)
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy microcompact path removed — tengu_cache_plum_violet is always true.
|
||||
// For contexts where cached microcompact is not available (external builds,
|
||||
// non-ant users, unsupported models, sub-agents), no compaction happens here;
|
||||
// autocompact handles context pressure instead.
|
||||
return { messages }
|
||||
}
|
||||
|
||||
/**
|
||||
* Cached microcompact path - uses cache editing API to remove tool results
|
||||
* without invalidating the cached prefix.
|
||||
*
|
||||
* Key differences from regular microcompact:
|
||||
* - Does NOT modify local message content (cache_reference and cache_edits are added at API layer)
|
||||
* - Uses count-based trigger/keep thresholds from GrowthBook config
|
||||
* - Takes precedence over regular microcompact (no disk persistence)
|
||||
* - Tracks tool results and queues cache edits for the API layer
|
||||
*/
|
||||
async function cachedMicrocompactPath(
|
||||
messages: Message[],
|
||||
querySource: QuerySource | undefined,
|
||||
): Promise<MicrocompactResult> {
|
||||
const mod = await getCachedMCModule()
|
||||
const state = ensureCachedMCState()
|
||||
const config = mod.getCachedMCConfig()
|
||||
|
||||
const compactableToolIds = new Set(collectCompactableToolIds(messages))
|
||||
// Second pass: register tool results grouped by user message
|
||||
for (const message of messages) {
|
||||
if (message.type === 'user' && Array.isArray(message.message.content)) {
|
||||
const groupIds: string[] = []
|
||||
for (const block of message.message.content) {
|
||||
if (
|
||||
block.type === 'tool_result' &&
|
||||
compactableToolIds.has(block.tool_use_id) &&
|
||||
!state.registeredTools.has(block.tool_use_id)
|
||||
) {
|
||||
mod.registerToolResult(state, block.tool_use_id)
|
||||
groupIds.push(block.tool_use_id)
|
||||
}
|
||||
}
|
||||
mod.registerToolMessage(state, groupIds)
|
||||
}
|
||||
}
|
||||
|
||||
const toolsToDelete = mod.getToolResultsToDelete(state)
|
||||
|
||||
if (toolsToDelete.length > 0) {
|
||||
// Create and queue the cache_edits block for the API layer
|
||||
const cacheEdits = mod.createCacheEditsBlock(state, toolsToDelete)
|
||||
if (cacheEdits) {
|
||||
pendingCacheEdits = cacheEdits
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`Cached MC deleting ${toolsToDelete.length} tool(s): ${toolsToDelete.join(', ')}`,
|
||||
)
|
||||
|
||||
// Log the event
|
||||
logEvent('tengu_cached_microcompact', {
|
||||
toolsDeleted: toolsToDelete.length,
|
||||
deletedToolIds: toolsToDelete.join(
|
||||
',',
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
activeToolCount: state.toolOrder.length - state.deletedRefs.size,
|
||||
triggerType:
|
||||
'auto' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
threshold: config.triggerThreshold,
|
||||
keepRecent: config.keepRecent,
|
||||
})
|
||||
|
||||
// Suppress warning after successful compaction
|
||||
suppressCompactWarning()
|
||||
|
||||
// Notify cache break detection that cache reads will legitimately drop
|
||||
if (feature('PROMPT_CACHE_BREAK_DETECTION')) {
|
||||
// Pass the actual querySource — isMainThreadSource now prefix-matches
|
||||
// so output-style variants enter here, and getTrackingKey keys on the
|
||||
// full source string, not the 'repl_main_thread' prefix.
|
||||
notifyCacheDeletion(querySource ?? 'repl_main_thread')
|
||||
}
|
||||
|
||||
// Return messages unchanged - cache_reference and cache_edits are added at API layer
|
||||
// Boundary message is deferred until after API response so we can use
|
||||
// actual cache_deleted_input_tokens from the API instead of client-side estimates
|
||||
// Capture the baseline cumulative cache_deleted_input_tokens from the last
|
||||
// assistant message so we can compute a per-operation delta after the API call
|
||||
const lastAsst = messages.findLast(m => m.type === 'assistant')
|
||||
const baseline =
|
||||
lastAsst?.type === 'assistant'
|
||||
? ((
|
||||
lastAsst.message.usage as unknown as Record<
|
||||
string,
|
||||
number | undefined
|
||||
>
|
||||
)?.cache_deleted_input_tokens ?? 0)
|
||||
: 0
|
||||
|
||||
return {
|
||||
messages,
|
||||
compactionInfo: {
|
||||
pendingCacheEdits: {
|
||||
trigger: 'auto',
|
||||
deletedToolIds: toolsToDelete,
|
||||
baselineCacheDeletedTokens: baseline,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// No compaction needed, return messages unchanged
|
||||
return { messages }
|
||||
}
|
||||
|
||||
/**
|
||||
* Time-based microcompact: when the gap since the last main-loop assistant
|
||||
* message exceeds the configured threshold, content-clear all but the most
|
||||
* recent N compactable tool results.
|
||||
*
|
||||
* Returns null when the trigger doesn't fire (disabled, wrong source, gap
|
||||
* under threshold, nothing to clear) — caller falls through to other paths.
|
||||
*
|
||||
* Unlike cached MC, this mutates message content directly. The cache is cold,
|
||||
* so there's no cached prefix to preserve via cache_edits.
|
||||
*/
|
||||
/**
|
||||
* Check whether the time-based trigger should fire for this request.
|
||||
*
|
||||
* Returns the measured gap (minutes since last assistant message) when the
|
||||
* trigger fires, or null when it doesn't (disabled, wrong source, under
|
||||
* threshold, no prior assistant, unparseable timestamp).
|
||||
*
|
||||
* Extracted so other pre-request paths (e.g. snip force-apply) can consult
|
||||
* the same predicate without coupling to the tool-result clearing action.
|
||||
*/
|
||||
export function evaluateTimeBasedTrigger(
|
||||
messages: Message[],
|
||||
querySource: QuerySource | undefined,
|
||||
): { gapMinutes: number; config: TimeBasedMCConfig } | null {
|
||||
const config = getTimeBasedMCConfig()
|
||||
// Require an explicit main-thread querySource. isMainThreadSource treats
|
||||
// undefined as main-thread (for cached-MC backward-compat), but several
|
||||
// callers (/context, /compact, analyzeContext) invoke microcompactMessages
|
||||
// without a source for analysis-only purposes — they should not trigger.
|
||||
if (!config.enabled || !querySource || !isMainThreadSource(querySource)) {
|
||||
return null
|
||||
}
|
||||
const lastAssistant = messages.findLast(m => m.type === 'assistant')
|
||||
if (!lastAssistant) {
|
||||
return null
|
||||
}
|
||||
const gapMinutes =
|
||||
(Date.now() - new Date(lastAssistant.timestamp).getTime()) / 60_000
|
||||
if (!Number.isFinite(gapMinutes) || gapMinutes < config.gapThresholdMinutes) {
|
||||
return null
|
||||
}
|
||||
return { gapMinutes, config }
|
||||
}
|
||||
|
||||
function maybeTimeBasedMicrocompact(
|
||||
messages: Message[],
|
||||
querySource: QuerySource | undefined,
|
||||
): MicrocompactResult | null {
|
||||
const trigger = evaluateTimeBasedTrigger(messages, querySource)
|
||||
if (!trigger) {
|
||||
return null
|
||||
}
|
||||
const { gapMinutes, config } = trigger
|
||||
|
||||
const compactableIds = collectCompactableToolIds(messages)
|
||||
|
||||
// Floor at 1: slice(-0) returns the full array (paradoxically keeps
|
||||
// everything), and clearing ALL results leaves the model with zero working
|
||||
// context. Neither degenerate is sensible — always keep at least the last.
|
||||
const keepRecent = Math.max(1, config.keepRecent)
|
||||
const keepSet = new Set(compactableIds.slice(-keepRecent))
|
||||
const clearSet = new Set(compactableIds.filter(id => !keepSet.has(id)))
|
||||
|
||||
if (clearSet.size === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
let tokensSaved = 0
|
||||
const result: Message[] = messages.map(message => {
|
||||
if (message.type !== 'user' || !Array.isArray(message.message.content)) {
|
||||
return message
|
||||
}
|
||||
let touched = false
|
||||
const newContent = message.message.content.map(block => {
|
||||
if (
|
||||
block.type === 'tool_result' &&
|
||||
clearSet.has(block.tool_use_id) &&
|
||||
block.content !== TIME_BASED_MC_CLEARED_MESSAGE
|
||||
) {
|
||||
tokensSaved += calculateToolResultTokens(block)
|
||||
touched = true
|
||||
return { ...block, content: TIME_BASED_MC_CLEARED_MESSAGE }
|
||||
}
|
||||
return block
|
||||
})
|
||||
if (!touched) return message
|
||||
return {
|
||||
...message,
|
||||
message: { ...message.message, content: newContent },
|
||||
}
|
||||
})
|
||||
|
||||
if (tokensSaved === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
logEvent('tengu_time_based_microcompact', {
|
||||
gapMinutes: Math.round(gapMinutes),
|
||||
gapThresholdMinutes: config.gapThresholdMinutes,
|
||||
toolsCleared: clearSet.size,
|
||||
toolsKept: keepSet.size,
|
||||
keepRecent: config.keepRecent,
|
||||
tokensSaved,
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
`[TIME-BASED MC] gap ${Math.round(gapMinutes)}min > ${config.gapThresholdMinutes}min, cleared ${clearSet.size} tool results (~${tokensSaved} tokens), kept last ${keepSet.size}`,
|
||||
)
|
||||
|
||||
suppressCompactWarning()
|
||||
// Cached-MC state (module-level) holds tool IDs registered on prior turns.
|
||||
// We just content-cleared some of those tools AND invalidated the server
|
||||
// cache by changing prompt content. If cached-MC runs next turn with the
|
||||
// stale state, it would try to cache_edit tools whose server-side entries
|
||||
// no longer exist. Reset it.
|
||||
resetMicrocompactState()
|
||||
// We just changed the prompt content — the next response's cache read will
|
||||
// be low, but that's us, not a break. Tell the detector to expect a drop.
|
||||
// notifyCacheDeletion (not notifyCompaction) because it's already imported
|
||||
// here and achieves the same false-positive suppression — adding the second
|
||||
// symbol to the import was flagged by the circular-deps check.
|
||||
// Pass the actual querySource: getTrackingKey returns the full source string
|
||||
// (e.g. 'repl_main_thread:outputStyle:custom'), not just the prefix.
|
||||
if (feature('PROMPT_CACHE_BREAK_DETECTION') && querySource) {
|
||||
notifyCacheDeletion(querySource)
|
||||
}
|
||||
|
||||
return { messages: result }
|
||||
}
|
||||
77
src/services/compact/postCompactCleanup.ts
Normal file
77
src/services/compact/postCompactCleanup.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { feature } from 'bun:bundle'
|
||||
import type { QuerySource } from '../../constants/querySource.js'
|
||||
import { clearSystemPromptSections } from '../../constants/systemPromptSections.js'
|
||||
import { getUserContext } from '../../context.js'
|
||||
import { clearSpeculativeChecks } from '../../tools/BashTool/bashPermissions.js'
|
||||
import { clearClassifierApprovals } from '../../utils/classifierApprovals.js'
|
||||
import { resetGetMemoryFilesCache } from '../../utils/claudemd.js'
|
||||
import { clearSessionMessagesCache } from '../../utils/sessionStorage.js'
|
||||
import { clearBetaTracingState } from '../../utils/telemetry/betaSessionTracing.js'
|
||||
import { resetMicrocompactState } from './microCompact.js'
|
||||
|
||||
/**
|
||||
* Run cleanup of caches and tracking state after compaction.
|
||||
* Call this after both auto-compact and manual /compact to free memory
|
||||
* held by tracking structures that are invalidated by compaction.
|
||||
*
|
||||
* Note: We intentionally do NOT clear invoked skill content here.
|
||||
* Skill content must survive across multiple compactions so that
|
||||
* createSkillAttachmentIfNeeded() can include the full skill text
|
||||
* in subsequent compaction attachments.
|
||||
*
|
||||
* querySource: pass the compacting query's source so we can skip
|
||||
* resets that would clobber main-thread module-level state. Subagents
|
||||
* (agent:*) run in the same process and share module-level state
|
||||
* (context-collapse store, getMemoryFiles one-shot hook flag,
|
||||
* getUserContext cache); resetting those when a SUBAGENT compacts
|
||||
* would corrupt the MAIN thread's state. All compaction callers should
|
||||
* pass querySource — undefined is only safe for callers that are
|
||||
* genuinely main-thread-only (/compact, /clear).
|
||||
*/
|
||||
export function runPostCompactCleanup(querySource?: QuerySource): void {
|
||||
// Subagents (agent:*) run in the same process and share module-level
|
||||
// state with the main thread. Only reset main-thread module-level state
|
||||
// (context-collapse, memory file cache) for main-thread compacts.
|
||||
// Same startsWith pattern as isMainThread (index.ts:188).
|
||||
const isMainThreadCompact =
|
||||
querySource === undefined ||
|
||||
querySource.startsWith('repl_main_thread') ||
|
||||
querySource === 'sdk'
|
||||
|
||||
resetMicrocompactState()
|
||||
if (feature('CONTEXT_COLLAPSE')) {
|
||||
if (isMainThreadCompact) {
|
||||
/* eslint-disable @typescript-eslint/no-require-imports */
|
||||
;(
|
||||
require('../contextCollapse/index.js') as typeof import('../contextCollapse/index.js')
|
||||
).resetContextCollapse()
|
||||
/* eslint-enable @typescript-eslint/no-require-imports */
|
||||
}
|
||||
}
|
||||
if (isMainThreadCompact) {
|
||||
// getUserContext is a memoized outer layer wrapping getClaudeMds() →
|
||||
// getMemoryFiles(). If only the inner getMemoryFiles cache is cleared,
|
||||
// the next turn hits the getUserContext cache and never reaches
|
||||
// getMemoryFiles(), so the armed InstructionsLoaded hook never fires.
|
||||
// Manual /compact already clears this explicitly at its call sites;
|
||||
// auto-compact and reactive-compact did not — this centralizes the
|
||||
// clear so all compaction paths behave consistently.
|
||||
getUserContext.cache.clear?.()
|
||||
resetGetMemoryFilesCache('compact')
|
||||
}
|
||||
clearSystemPromptSections()
|
||||
clearClassifierApprovals()
|
||||
clearSpeculativeChecks()
|
||||
// Intentionally NOT calling resetSentSkillNames(): re-injecting the full
|
||||
// skill_listing (~4K tokens) post-compact is pure cache_creation. The
|
||||
// model still has SkillTool in schema, invoked_skills preserves used
|
||||
// skills, and dynamic additions are handled by skillChangeDetector /
|
||||
// cacheUtils resets. See compactConversation() for full rationale.
|
||||
clearBetaTracingState()
|
||||
if (feature('COMMIT_ATTRIBUTION')) {
|
||||
void import('../../utils/attributionHooks.js').then(m =>
|
||||
m.sweepFileContentCache(),
|
||||
)
|
||||
}
|
||||
clearSessionMessagesCache()
|
||||
}
|
||||
374
src/services/compact/prompt.ts
Normal file
374
src/services/compact/prompt.ts
Normal file
@@ -0,0 +1,374 @@
|
||||
import { feature } from 'bun:bundle'
|
||||
import type { PartialCompactDirection } from '../../types/message.js'
|
||||
|
||||
// Dead code elimination: conditional import for proactive mode
|
||||
/* eslint-disable @typescript-eslint/no-require-imports */
|
||||
const proactiveModule =
|
||||
feature('PROACTIVE') || feature('KAIROS')
|
||||
? (require('../../proactive/index.js') as typeof import('../../proactive/index.js'))
|
||||
: null
|
||||
/* eslint-enable @typescript-eslint/no-require-imports */
|
||||
|
||||
// Aggressive no-tools preamble. The cache-sharing fork path inherits the
|
||||
// parent's full tool set (required for cache-key match), and on Sonnet 4.6+
|
||||
// adaptive-thinking models the model sometimes attempts a tool call despite
|
||||
// the weaker trailer instruction. With maxTurns: 1, a denied tool call means
|
||||
// no text output → falls through to the streaming fallback (2.79% on 4.6 vs
|
||||
// 0.01% on 4.5). Putting this FIRST and making it explicit about rejection
|
||||
// consequences prevents the wasted turn.
|
||||
const NO_TOOLS_PREAMBLE = `CRITICAL: Respond with TEXT ONLY. Do NOT call any tools.
|
||||
|
||||
- Do NOT use Read, Bash, Grep, Glob, Edit, Write, or ANY other tool.
|
||||
- You already have all the context you need in the conversation above.
|
||||
- Tool calls will be REJECTED and will waste your only turn — you will fail the task.
|
||||
- Your entire response must be plain text: an <analysis> block followed by a <summary> block.
|
||||
|
||||
`
|
||||
|
||||
// Two variants: BASE scopes to "the conversation", PARTIAL scopes to "the
|
||||
// recent messages". The <analysis> block is a drafting scratchpad that
|
||||
// formatCompactSummary() strips before the summary reaches context.
|
||||
const DETAILED_ANALYSIS_INSTRUCTION_BASE = `Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:
|
||||
|
||||
1. Chronologically analyze each message and section of the conversation. For each section thoroughly identify:
|
||||
- The user's explicit requests and intents
|
||||
- Your approach to addressing the user's requests
|
||||
- Key decisions, technical concepts and code patterns
|
||||
- Specific details like:
|
||||
- file names
|
||||
- full code snippets
|
||||
- function signatures
|
||||
- file edits
|
||||
- Errors that you ran into and how you fixed them
|
||||
- Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
|
||||
2. Double-check for technical accuracy and completeness, addressing each required element thoroughly.`
|
||||
|
||||
const DETAILED_ANALYSIS_INSTRUCTION_PARTIAL = `Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:
|
||||
|
||||
1. Analyze the recent messages chronologically. For each section thoroughly identify:
|
||||
- The user's explicit requests and intents
|
||||
- Your approach to addressing the user's requests
|
||||
- Key decisions, technical concepts and code patterns
|
||||
- Specific details like:
|
||||
- file names
|
||||
- full code snippets
|
||||
- function signatures
|
||||
- file edits
|
||||
- Errors that you ran into and how you fixed them
|
||||
- Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
|
||||
2. Double-check for technical accuracy and completeness, addressing each required element thoroughly.`
|
||||
|
||||
const BASE_COMPACT_PROMPT = `Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.
|
||||
This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing development work without losing context.
|
||||
|
||||
${DETAILED_ANALYSIS_INSTRUCTION_BASE}
|
||||
|
||||
Your summary should include the following sections:
|
||||
|
||||
1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail
|
||||
2. Key Technical Concepts: List all important technical concepts, technologies, and frameworks discussed.
|
||||
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important.
|
||||
4. Errors and fixes: List all errors that you ran into, and how you fixed them. Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
|
||||
5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
|
||||
6. All user messages: List ALL user messages that are not tool results. These are critical for understanding the users' feedback and changing intent.
|
||||
7. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
|
||||
8. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
|
||||
9. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's most recent explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests or really old requests that were already completed without confirming with the user first.
|
||||
If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.
|
||||
|
||||
Here's an example of how your output should be structured:
|
||||
|
||||
<example>
|
||||
<analysis>
|
||||
[Your thought process, ensuring all points are covered thoroughly and accurately]
|
||||
</analysis>
|
||||
|
||||
<summary>
|
||||
1. Primary Request and Intent:
|
||||
[Detailed description]
|
||||
|
||||
2. Key Technical Concepts:
|
||||
- [Concept 1]
|
||||
- [Concept 2]
|
||||
- [...]
|
||||
|
||||
3. Files and Code Sections:
|
||||
- [File Name 1]
|
||||
- [Summary of why this file is important]
|
||||
- [Summary of the changes made to this file, if any]
|
||||
- [Important Code Snippet]
|
||||
- [File Name 2]
|
||||
- [Important Code Snippet]
|
||||
- [...]
|
||||
|
||||
4. Errors and fixes:
|
||||
- [Detailed description of error 1]:
|
||||
- [How you fixed the error]
|
||||
- [User feedback on the error if any]
|
||||
- [...]
|
||||
|
||||
5. Problem Solving:
|
||||
[Description of solved problems and ongoing troubleshooting]
|
||||
|
||||
6. All user messages:
|
||||
- [Detailed non tool use user message]
|
||||
- [...]
|
||||
|
||||
7. Pending Tasks:
|
||||
- [Task 1]
|
||||
- [Task 2]
|
||||
- [...]
|
||||
|
||||
8. Current Work:
|
||||
[Precise description of current work]
|
||||
|
||||
9. Optional Next Step:
|
||||
[Optional Next step to take]
|
||||
|
||||
</summary>
|
||||
</example>
|
||||
|
||||
Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response.
|
||||
|
||||
There may be additional summarization instructions provided in the included context. If so, remember to follow these instructions when creating the above summary. Examples of instructions include:
|
||||
<example>
|
||||
## Compact Instructions
|
||||
When summarizing the conversation focus on typescript code changes and also remember the mistakes you made and how you fixed them.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
# Summary instructions
|
||||
When you are using compact - please focus on test output and code changes. Include file reads verbatim.
|
||||
</example>
|
||||
`
|
||||
|
||||
const PARTIAL_COMPACT_PROMPT = `Your task is to create a detailed summary of the RECENT portion of the conversation — the messages that follow earlier retained context. The earlier messages are being kept intact and do NOT need to be summarized. Focus your summary on what was discussed, learned, and accomplished in the recent messages only.
|
||||
|
||||
${DETAILED_ANALYSIS_INSTRUCTION_PARTIAL}
|
||||
|
||||
Your summary should include the following sections:
|
||||
|
||||
1. Primary Request and Intent: Capture the user's explicit requests and intents from the recent messages
|
||||
2. Key Technical Concepts: List important technical concepts, technologies, and frameworks discussed recently.
|
||||
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Include full code snippets where applicable and include a summary of why this file read or edit is important.
|
||||
4. Errors and fixes: List errors encountered and how they were fixed.
|
||||
5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
|
||||
6. All user messages: List ALL user messages from the recent portion that are not tool results.
|
||||
7. Pending Tasks: Outline any pending tasks from the recent messages.
|
||||
8. Current Work: Describe precisely what was being worked on immediately before this summary request.
|
||||
9. Optional Next Step: List the next step related to the most recent work. Include direct quotes from the most recent conversation.
|
||||
|
||||
Here's an example of how your output should be structured:
|
||||
|
||||
<example>
|
||||
<analysis>
|
||||
[Your thought process, ensuring all points are covered thoroughly and accurately]
|
||||
</analysis>
|
||||
|
||||
<summary>
|
||||
1. Primary Request and Intent:
|
||||
[Detailed description]
|
||||
|
||||
2. Key Technical Concepts:
|
||||
- [Concept 1]
|
||||
- [Concept 2]
|
||||
|
||||
3. Files and Code Sections:
|
||||
- [File Name 1]
|
||||
- [Summary of why this file is important]
|
||||
- [Important Code Snippet]
|
||||
|
||||
4. Errors and fixes:
|
||||
- [Error description]:
|
||||
- [How you fixed it]
|
||||
|
||||
5. Problem Solving:
|
||||
[Description]
|
||||
|
||||
6. All user messages:
|
||||
- [Detailed non tool use user message]
|
||||
|
||||
7. Pending Tasks:
|
||||
- [Task 1]
|
||||
|
||||
8. Current Work:
|
||||
[Precise description of current work]
|
||||
|
||||
9. Optional Next Step:
|
||||
[Optional Next step to take]
|
||||
|
||||
</summary>
|
||||
</example>
|
||||
|
||||
Please provide your summary based on the RECENT messages only (after the retained earlier context), following this structure and ensuring precision and thoroughness in your response.
|
||||
`
|
||||
|
||||
// 'up_to': model sees only the summarized prefix (cache hit). Summary will
|
||||
// precede kept recent messages, hence "Context for Continuing Work" section.
|
||||
const PARTIAL_COMPACT_UP_TO_PROMPT = `Your task is to create a detailed summary of this conversation. This summary will be placed at the start of a continuing session; newer messages that build on this context will follow after your summary (you do not see them here). Summarize thoroughly so that someone reading only your summary and then the newer messages can fully understand what happened and continue the work.
|
||||
|
||||
${DETAILED_ANALYSIS_INSTRUCTION_BASE}
|
||||
|
||||
Your summary should include the following sections:
|
||||
|
||||
1. Primary Request and Intent: Capture the user's explicit requests and intents in detail
|
||||
2. Key Technical Concepts: List important technical concepts, technologies, and frameworks discussed.
|
||||
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Include full code snippets where applicable and include a summary of why this file read or edit is important.
|
||||
4. Errors and fixes: List errors encountered and how they were fixed.
|
||||
5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
|
||||
6. All user messages: List ALL user messages that are not tool results.
|
||||
7. Pending Tasks: Outline any pending tasks.
|
||||
8. Work Completed: Describe what was accomplished by the end of this portion.
|
||||
9. Context for Continuing Work: Summarize any context, decisions, or state that would be needed to understand and continue the work in subsequent messages.
|
||||
|
||||
Here's an example of how your output should be structured:
|
||||
|
||||
<example>
|
||||
<analysis>
|
||||
[Your thought process, ensuring all points are covered thoroughly and accurately]
|
||||
</analysis>
|
||||
|
||||
<summary>
|
||||
1. Primary Request and Intent:
|
||||
[Detailed description]
|
||||
|
||||
2. Key Technical Concepts:
|
||||
- [Concept 1]
|
||||
- [Concept 2]
|
||||
|
||||
3. Files and Code Sections:
|
||||
- [File Name 1]
|
||||
- [Summary of why this file is important]
|
||||
- [Important Code Snippet]
|
||||
|
||||
4. Errors and fixes:
|
||||
- [Error description]:
|
||||
- [How you fixed it]
|
||||
|
||||
5. Problem Solving:
|
||||
[Description]
|
||||
|
||||
6. All user messages:
|
||||
- [Detailed non tool use user message]
|
||||
|
||||
7. Pending Tasks:
|
||||
- [Task 1]
|
||||
|
||||
8. Work Completed:
|
||||
[Description of what was accomplished]
|
||||
|
||||
9. Context for Continuing Work:
|
||||
[Key context, decisions, or state needed to continue the work]
|
||||
|
||||
</summary>
|
||||
</example>
|
||||
|
||||
Please provide your summary following this structure, ensuring precision and thoroughness in your response.
|
||||
`
|
||||
|
||||
const NO_TOOLS_TRAILER =
|
||||
'\n\nREMINDER: Do NOT call any tools. Respond with plain text only — ' +
|
||||
'an <analysis> block followed by a <summary> block. ' +
|
||||
'Tool calls will be rejected and you will fail the task.'
|
||||
|
||||
export function getPartialCompactPrompt(
|
||||
customInstructions?: string,
|
||||
direction: PartialCompactDirection = 'from',
|
||||
): string {
|
||||
const template =
|
||||
direction === 'up_to'
|
||||
? PARTIAL_COMPACT_UP_TO_PROMPT
|
||||
: PARTIAL_COMPACT_PROMPT
|
||||
let prompt = NO_TOOLS_PREAMBLE + template
|
||||
|
||||
if (customInstructions && customInstructions.trim() !== '') {
|
||||
prompt += `\n\nAdditional Instructions:\n${customInstructions}`
|
||||
}
|
||||
|
||||
prompt += NO_TOOLS_TRAILER
|
||||
|
||||
return prompt
|
||||
}
|
||||
|
||||
export function getCompactPrompt(customInstructions?: string): string {
|
||||
let prompt = NO_TOOLS_PREAMBLE + BASE_COMPACT_PROMPT
|
||||
|
||||
if (customInstructions && customInstructions.trim() !== '') {
|
||||
prompt += `\n\nAdditional Instructions:\n${customInstructions}`
|
||||
}
|
||||
|
||||
prompt += NO_TOOLS_TRAILER
|
||||
|
||||
return prompt
|
||||
}
|
||||
|
||||
/**
|
||||
* Formats the compact summary by stripping the <analysis> drafting scratchpad
|
||||
* and replacing <summary> XML tags with readable section headers.
|
||||
* @param summary The raw summary string potentially containing <analysis> and <summary> XML tags
|
||||
* @returns The formatted summary with analysis stripped and summary tags replaced by headers
|
||||
*/
|
||||
export function formatCompactSummary(summary: string): string {
|
||||
let formattedSummary = summary
|
||||
|
||||
// Strip analysis section — it's a drafting scratchpad that improves summary
|
||||
// quality but has no informational value once the summary is written.
|
||||
formattedSummary = formattedSummary.replace(
|
||||
/<analysis>[\s\S]*?<\/analysis>/,
|
||||
'',
|
||||
)
|
||||
|
||||
// Extract and format summary section
|
||||
const summaryMatch = formattedSummary.match(/<summary>([\s\S]*?)<\/summary>/)
|
||||
if (summaryMatch) {
|
||||
const content = summaryMatch[1] || ''
|
||||
formattedSummary = formattedSummary.replace(
|
||||
/<summary>[\s\S]*?<\/summary>/,
|
||||
`Summary:\n${content.trim()}`,
|
||||
)
|
||||
}
|
||||
|
||||
// Clean up extra whitespace between sections
|
||||
formattedSummary = formattedSummary.replace(/\n\n+/g, '\n\n')
|
||||
|
||||
return formattedSummary.trim()
|
||||
}
|
||||
|
||||
export function getCompactUserSummaryMessage(
|
||||
summary: string,
|
||||
suppressFollowUpQuestions?: boolean,
|
||||
transcriptPath?: string,
|
||||
recentMessagesPreserved?: boolean,
|
||||
): string {
|
||||
const formattedSummary = formatCompactSummary(summary)
|
||||
|
||||
let baseSummary = `This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.
|
||||
|
||||
${formattedSummary}`
|
||||
|
||||
if (transcriptPath) {
|
||||
baseSummary += `\n\nIf you need specific details from before compaction (like exact code snippets, error messages, or content you generated), read the full transcript at: ${transcriptPath}`
|
||||
}
|
||||
|
||||
if (recentMessagesPreserved) {
|
||||
baseSummary += `\n\nRecent messages are preserved verbatim.`
|
||||
}
|
||||
|
||||
if (suppressFollowUpQuestions) {
|
||||
let continuation = `${baseSummary}
|
||||
Continue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, do not preface with "I'll continue" or similar. Pick up the last task as if the break never happened.`
|
||||
|
||||
if (
|
||||
(feature('PROACTIVE') || feature('KAIROS')) &&
|
||||
proactiveModule?.isProactiveActive()
|
||||
) {
|
||||
continuation += `
|
||||
|
||||
You are running in autonomous/proactive mode. This is NOT a first wake-up — you were already working autonomously before compaction. Continue your work loop: pick up where you left off based on the summary above. Do not greet the user or ask what to work on.`
|
||||
}
|
||||
|
||||
return continuation
|
||||
}
|
||||
|
||||
return baseSummary
|
||||
}
|
||||
630
src/services/compact/sessionMemoryCompact.ts
Normal file
630
src/services/compact/sessionMemoryCompact.ts
Normal file
@@ -0,0 +1,630 @@
|
||||
/**
|
||||
* EXPERIMENT: Session memory compaction
|
||||
*/
|
||||
|
||||
import type { AgentId } from '../../types/ids.js'
|
||||
import type { HookResultMessage, Message } from '../../types/message.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import {
|
||||
createCompactBoundaryMessage,
|
||||
createUserMessage,
|
||||
isCompactBoundaryMessage,
|
||||
} from '../../utils/messages.js'
|
||||
import { getMainLoopModel } from '../../utils/model/model.js'
|
||||
import { getSessionMemoryPath } from '../../utils/permissions/filesystem.js'
|
||||
import { processSessionStartHooks } from '../../utils/sessionStart.js'
|
||||
import { getTranscriptPath } from '../../utils/sessionStorage.js'
|
||||
import { tokenCountFromLastAPIResponse } from '../../utils/tokens.js'
|
||||
import { extractDiscoveredToolNames } from '../../utils/toolSearch.js'
|
||||
import {
|
||||
getDynamicConfig_BLOCKS_ON_INIT,
|
||||
getFeatureValue_CACHED_MAY_BE_STALE,
|
||||
} from '../analytics/growthbook.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
import {
|
||||
isSessionMemoryEmpty,
|
||||
truncateSessionMemoryForCompact,
|
||||
} from '../SessionMemory/prompts.js'
|
||||
import {
|
||||
getLastSummarizedMessageId,
|
||||
getSessionMemoryContent,
|
||||
waitForSessionMemoryExtraction,
|
||||
} from '../SessionMemory/sessionMemoryUtils.js'
|
||||
import {
|
||||
annotateBoundaryWithPreservedSegment,
|
||||
buildPostCompactMessages,
|
||||
type CompactionResult,
|
||||
createPlanAttachmentIfNeeded,
|
||||
} from './compact.js'
|
||||
import { estimateMessageTokens } from './microCompact.js'
|
||||
import { getCompactUserSummaryMessage } from './prompt.js'
|
||||
|
||||
/**
|
||||
* Configuration for session memory compaction thresholds
|
||||
*/
|
||||
export type SessionMemoryCompactConfig = {
|
||||
/** Minimum tokens to preserve after compaction */
|
||||
minTokens: number
|
||||
/** Minimum number of messages with text blocks to keep */
|
||||
minTextBlockMessages: number
|
||||
/** Maximum tokens to preserve after compaction (hard cap) */
|
||||
maxTokens: number
|
||||
}
|
||||
|
||||
// Default configuration values (exported for use in tests)
|
||||
export const DEFAULT_SM_COMPACT_CONFIG: SessionMemoryCompactConfig = {
|
||||
minTokens: 10_000,
|
||||
minTextBlockMessages: 5,
|
||||
maxTokens: 40_000,
|
||||
}
|
||||
|
||||
// Current configuration (starts with defaults)
|
||||
let smCompactConfig: SessionMemoryCompactConfig = {
|
||||
...DEFAULT_SM_COMPACT_CONFIG,
|
||||
}
|
||||
|
||||
// Track whether config has been initialized from remote
|
||||
let configInitialized = false
|
||||
|
||||
/**
|
||||
* Set the session memory compact configuration
|
||||
*/
|
||||
export function setSessionMemoryCompactConfig(
|
||||
config: Partial<SessionMemoryCompactConfig>,
|
||||
): void {
|
||||
smCompactConfig = {
|
||||
...smCompactConfig,
|
||||
...config,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current session memory compact configuration
|
||||
*/
|
||||
export function getSessionMemoryCompactConfig(): SessionMemoryCompactConfig {
|
||||
return { ...smCompactConfig }
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset config state (useful for testing)
|
||||
*/
|
||||
export function resetSessionMemoryCompactConfig(): void {
|
||||
smCompactConfig = { ...DEFAULT_SM_COMPACT_CONFIG }
|
||||
configInitialized = false
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize configuration from remote config (GrowthBook).
|
||||
* Only fetches once per session - subsequent calls return immediately.
|
||||
*/
|
||||
async function initSessionMemoryCompactConfig(): Promise<void> {
|
||||
if (configInitialized) {
|
||||
return
|
||||
}
|
||||
configInitialized = true
|
||||
|
||||
// Load config from GrowthBook, merging with defaults
|
||||
const remoteConfig = await getDynamicConfig_BLOCKS_ON_INIT<
|
||||
Partial<SessionMemoryCompactConfig>
|
||||
>('tengu_sm_compact_config', {})
|
||||
|
||||
// Only use remote values if they are explicitly set (positive numbers)
|
||||
// This ensures sensible defaults aren't overridden by zero values
|
||||
const config: SessionMemoryCompactConfig = {
|
||||
minTokens:
|
||||
remoteConfig.minTokens && remoteConfig.minTokens > 0
|
||||
? remoteConfig.minTokens
|
||||
: DEFAULT_SM_COMPACT_CONFIG.minTokens,
|
||||
minTextBlockMessages:
|
||||
remoteConfig.minTextBlockMessages && remoteConfig.minTextBlockMessages > 0
|
||||
? remoteConfig.minTextBlockMessages
|
||||
: DEFAULT_SM_COMPACT_CONFIG.minTextBlockMessages,
|
||||
maxTokens:
|
||||
remoteConfig.maxTokens && remoteConfig.maxTokens > 0
|
||||
? remoteConfig.maxTokens
|
||||
: DEFAULT_SM_COMPACT_CONFIG.maxTokens,
|
||||
}
|
||||
setSessionMemoryCompactConfig(config)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message contains text blocks (text content for user/assistant interaction)
|
||||
*/
|
||||
export function hasTextBlocks(message: Message): boolean {
|
||||
if (message.type === 'assistant') {
|
||||
const content = message.message.content
|
||||
return content.some(block => block.type === 'text')
|
||||
}
|
||||
if (message.type === 'user') {
|
||||
const content = message.message.content
|
||||
if (typeof content === 'string') {
|
||||
return content.length > 0
|
||||
}
|
||||
if (Array.isArray(content)) {
|
||||
return content.some(block => block.type === 'text')
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message contains tool_result blocks and return their tool_use_ids
|
||||
*/
|
||||
function getToolResultIds(message: Message): string[] {
|
||||
if (message.type !== 'user') {
|
||||
return []
|
||||
}
|
||||
const content = message.message.content
|
||||
if (!Array.isArray(content)) {
|
||||
return []
|
||||
}
|
||||
const ids: string[] = []
|
||||
for (const block of content) {
|
||||
if (block.type === 'tool_result') {
|
||||
ids.push(block.tool_use_id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message contains tool_use blocks with any of the given ids
|
||||
*/
|
||||
function hasToolUseWithIds(message: Message, toolUseIds: Set<string>): boolean {
|
||||
if (message.type !== 'assistant') {
|
||||
return false
|
||||
}
|
||||
const content = message.message.content
|
||||
if (!Array.isArray(content)) {
|
||||
return false
|
||||
}
|
||||
return content.some(
|
||||
block => block.type === 'tool_use' && toolUseIds.has(block.id),
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjust the start index to ensure we don't split tool_use/tool_result pairs
|
||||
* or thinking blocks that share the same message.id with kept assistant messages.
|
||||
*
|
||||
* If ANY message we're keeping contains tool_result blocks, we need to
|
||||
* include the preceding assistant message(s) that contain the matching tool_use blocks.
|
||||
*
|
||||
* Additionally, if ANY assistant message in the kept range has the same message.id
|
||||
* as a preceding assistant message (which may contain thinking blocks), we need to
|
||||
* include those messages so they can be properly merged by normalizeMessagesForAPI.
|
||||
*
|
||||
* This handles the case where streaming yields separate messages per content block
|
||||
* (thinking, tool_use, etc.) with the same message.id but different uuids. If the
|
||||
* startIndex lands on one of these streaming messages, we need to look at ALL kept
|
||||
* messages for tool_results, not just the first one.
|
||||
*
|
||||
* Example bug scenarios this fixes:
|
||||
*
|
||||
* Tool pair scenario:
|
||||
* Session storage (before compaction):
|
||||
* Index N: assistant, message.id: X, content: [thinking]
|
||||
* Index N+1: assistant, message.id: X, content: [tool_use: ORPHAN_ID]
|
||||
* Index N+2: assistant, message.id: X, content: [tool_use: VALID_ID]
|
||||
* Index N+3: user, content: [tool_result: ORPHAN_ID, tool_result: VALID_ID]
|
||||
*
|
||||
* If startIndex = N+2:
|
||||
* - Old code: checked only message N+2 for tool_results, found none, returned N+2
|
||||
* - After slicing and normalizeMessagesForAPI merging by message.id:
|
||||
* msg[1]: assistant with [tool_use: VALID_ID] (ORPHAN tool_use was excluded!)
|
||||
* msg[2]: user with [tool_result: ORPHAN_ID, tool_result: VALID_ID]
|
||||
* - API error: orphan tool_result references non-existent tool_use
|
||||
*
|
||||
* Thinking block scenario:
|
||||
* Session storage (before compaction):
|
||||
* Index N: assistant, message.id: X, content: [thinking]
|
||||
* Index N+1: assistant, message.id: X, content: [tool_use: ID]
|
||||
* Index N+2: user, content: [tool_result: ID]
|
||||
*
|
||||
* If startIndex = N+1:
|
||||
* - Without this fix: thinking block at N is excluded
|
||||
* - After normalizeMessagesForAPI: thinking block is lost (no message to merge with)
|
||||
*
|
||||
* Fixed code: detects that message N+1 has same message.id as N, adjusts to N.
|
||||
*/
|
||||
export function adjustIndexToPreserveAPIInvariants(
|
||||
messages: Message[],
|
||||
startIndex: number,
|
||||
): number {
|
||||
if (startIndex <= 0 || startIndex >= messages.length) {
|
||||
return startIndex
|
||||
}
|
||||
|
||||
let adjustedIndex = startIndex
|
||||
|
||||
// Step 1: Handle tool_use/tool_result pairs
|
||||
// Collect tool_result IDs from ALL messages in the kept range
|
||||
const allToolResultIds: string[] = []
|
||||
for (let i = startIndex; i < messages.length; i++) {
|
||||
allToolResultIds.push(...getToolResultIds(messages[i]!))
|
||||
}
|
||||
|
||||
if (allToolResultIds.length > 0) {
|
||||
// Collect tool_use IDs already in the kept range
|
||||
const toolUseIdsInKeptRange = new Set<string>()
|
||||
for (let i = adjustedIndex; i < messages.length; i++) {
|
||||
const msg = messages[i]!
|
||||
if (msg.type === 'assistant' && Array.isArray(msg.message.content)) {
|
||||
for (const block of msg.message.content) {
|
||||
if (block.type === 'tool_use') {
|
||||
toolUseIdsInKeptRange.add(block.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only look for tool_uses that are NOT already in the kept range
|
||||
const neededToolUseIds = new Set(
|
||||
allToolResultIds.filter(id => !toolUseIdsInKeptRange.has(id)),
|
||||
)
|
||||
|
||||
// Find the assistant message(s) with matching tool_use blocks
|
||||
for (let i = adjustedIndex - 1; i >= 0 && neededToolUseIds.size > 0; i--) {
|
||||
const message = messages[i]!
|
||||
if (hasToolUseWithIds(message, neededToolUseIds)) {
|
||||
adjustedIndex = i
|
||||
// Remove found tool_use_ids from the set
|
||||
if (
|
||||
message.type === 'assistant' &&
|
||||
Array.isArray(message.message.content)
|
||||
) {
|
||||
for (const block of message.message.content) {
|
||||
if (block.type === 'tool_use' && neededToolUseIds.has(block.id)) {
|
||||
neededToolUseIds.delete(block.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Handle thinking blocks that share message.id with kept assistant messages
|
||||
// Collect all message.ids from assistant messages in the kept range
|
||||
const messageIdsInKeptRange = new Set<string>()
|
||||
for (let i = adjustedIndex; i < messages.length; i++) {
|
||||
const msg = messages[i]!
|
||||
if (msg.type === 'assistant' && msg.message.id) {
|
||||
messageIdsInKeptRange.add(msg.message.id)
|
||||
}
|
||||
}
|
||||
|
||||
// Look backwards for assistant messages with the same message.id that are not in the kept range
|
||||
// These may contain thinking blocks that need to be merged by normalizeMessagesForAPI
|
||||
for (let i = adjustedIndex - 1; i >= 0; i--) {
|
||||
const message = messages[i]!
|
||||
if (
|
||||
message.type === 'assistant' &&
|
||||
message.message.id &&
|
||||
messageIdsInKeptRange.has(message.message.id)
|
||||
) {
|
||||
// This message has the same message.id as one in the kept range
|
||||
// Include it so thinking blocks can be properly merged
|
||||
adjustedIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
return adjustedIndex
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the starting index for messages to keep after compaction.
|
||||
* Starts from lastSummarizedMessageId, then expands backwards to meet minimums:
|
||||
* - At least config.minTokens tokens
|
||||
* - At least config.minTextBlockMessages messages with text blocks
|
||||
* Stops expanding if config.maxTokens is reached.
|
||||
* Also ensures tool_use/tool_result pairs are not split.
|
||||
*/
|
||||
export function calculateMessagesToKeepIndex(
|
||||
messages: Message[],
|
||||
lastSummarizedIndex: number,
|
||||
): number {
|
||||
if (messages.length === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
const config = getSessionMemoryCompactConfig()
|
||||
|
||||
// Start from the message after lastSummarizedIndex
|
||||
// If lastSummarizedIndex is -1 (not found) or messages.length (no summarized id),
|
||||
// we start with no messages kept
|
||||
let startIndex =
|
||||
lastSummarizedIndex >= 0 ? lastSummarizedIndex + 1 : messages.length
|
||||
|
||||
// Calculate current tokens and text-block message count from startIndex to end
|
||||
let totalTokens = 0
|
||||
let textBlockMessageCount = 0
|
||||
for (let i = startIndex; i < messages.length; i++) {
|
||||
const msg = messages[i]!
|
||||
totalTokens += estimateMessageTokens([msg])
|
||||
if (hasTextBlocks(msg)) {
|
||||
textBlockMessageCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we already hit the max cap
|
||||
if (totalTokens >= config.maxTokens) {
|
||||
return adjustIndexToPreserveAPIInvariants(messages, startIndex)
|
||||
}
|
||||
|
||||
// Check if we already meet both minimums
|
||||
if (
|
||||
totalTokens >= config.minTokens &&
|
||||
textBlockMessageCount >= config.minTextBlockMessages
|
||||
) {
|
||||
return adjustIndexToPreserveAPIInvariants(messages, startIndex)
|
||||
}
|
||||
|
||||
// Expand backwards until we meet both minimums or hit max cap.
|
||||
// Floor at the last boundary: the preserved-segment chain has a disk
|
||||
// discontinuity there (att[0]→summary shortcut from dedup-skip), which
|
||||
// would let the loader's tail→head walk bypass inner preserved messages
|
||||
// and then prune them. Reactive compact already slices at the boundary
|
||||
// via getMessagesAfterCompactBoundary; this is the same invariant.
|
||||
const idx = messages.findLastIndex(m => isCompactBoundaryMessage(m))
|
||||
const floor = idx === -1 ? 0 : idx + 1
|
||||
for (let i = startIndex - 1; i >= floor; i--) {
|
||||
const msg = messages[i]!
|
||||
const msgTokens = estimateMessageTokens([msg])
|
||||
totalTokens += msgTokens
|
||||
if (hasTextBlocks(msg)) {
|
||||
textBlockMessageCount++
|
||||
}
|
||||
startIndex = i
|
||||
|
||||
// Stop if we hit the max cap
|
||||
if (totalTokens >= config.maxTokens) {
|
||||
break
|
||||
}
|
||||
|
||||
// Stop if we meet both minimums
|
||||
if (
|
||||
totalTokens >= config.minTokens &&
|
||||
textBlockMessageCount >= config.minTextBlockMessages
|
||||
) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Adjust for tool pairs
|
||||
return adjustIndexToPreserveAPIInvariants(messages, startIndex)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if we should use session memory for compaction
|
||||
* Uses cached gate values to avoid blocking on Statsig initialization
|
||||
*/
|
||||
export function shouldUseSessionMemoryCompaction(): boolean {
|
||||
// Allow env var override for eval runs and testing
|
||||
if (isEnvTruthy(process.env.ENABLE_CLAUDE_CODE_SM_COMPACT)) {
|
||||
return true
|
||||
}
|
||||
if (isEnvTruthy(process.env.DISABLE_CLAUDE_CODE_SM_COMPACT)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const sessionMemoryFlag = getFeatureValue_CACHED_MAY_BE_STALE(
|
||||
'tengu_session_memory',
|
||||
false,
|
||||
)
|
||||
const smCompactFlag = getFeatureValue_CACHED_MAY_BE_STALE(
|
||||
'tengu_sm_compact',
|
||||
false,
|
||||
)
|
||||
const shouldUse = sessionMemoryFlag && smCompactFlag
|
||||
|
||||
// Log flag states for debugging (ant-only to avoid noise in external logs)
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logEvent('tengu_sm_compact_flag_check', {
|
||||
tengu_session_memory: sessionMemoryFlag,
|
||||
tengu_sm_compact: smCompactFlag,
|
||||
should_use: shouldUse,
|
||||
})
|
||||
}
|
||||
|
||||
return shouldUse
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a CompactionResult from session memory
|
||||
*/
|
||||
function createCompactionResultFromSessionMemory(
|
||||
messages: Message[],
|
||||
sessionMemory: string,
|
||||
messagesToKeep: Message[],
|
||||
hookResults: HookResultMessage[],
|
||||
transcriptPath: string,
|
||||
agentId?: AgentId,
|
||||
): CompactionResult {
|
||||
const preCompactTokenCount = tokenCountFromLastAPIResponse(messages)
|
||||
|
||||
const boundaryMarker = createCompactBoundaryMessage(
|
||||
'auto',
|
||||
preCompactTokenCount ?? 0,
|
||||
messages[messages.length - 1]?.uuid,
|
||||
)
|
||||
const preCompactDiscovered = extractDiscoveredToolNames(messages)
|
||||
if (preCompactDiscovered.size > 0) {
|
||||
boundaryMarker.compactMetadata.preCompactDiscoveredTools = [
|
||||
...preCompactDiscovered,
|
||||
].sort()
|
||||
}
|
||||
|
||||
// Truncate oversized sections to prevent session memory from consuming
|
||||
// the entire post-compact token budget
|
||||
const { truncatedContent, wasTruncated } =
|
||||
truncateSessionMemoryForCompact(sessionMemory)
|
||||
|
||||
let summaryContent = getCompactUserSummaryMessage(
|
||||
truncatedContent,
|
||||
true,
|
||||
transcriptPath,
|
||||
true,
|
||||
)
|
||||
|
||||
if (wasTruncated) {
|
||||
const memoryPath = getSessionMemoryPath()
|
||||
summaryContent += `\n\nSome session memory sections were truncated for length. The full session memory can be viewed at: ${memoryPath}`
|
||||
}
|
||||
|
||||
const summaryMessages = [
|
||||
createUserMessage({
|
||||
content: summaryContent,
|
||||
isCompactSummary: true,
|
||||
isVisibleInTranscriptOnly: true,
|
||||
}),
|
||||
]
|
||||
|
||||
const planAttachment = createPlanAttachmentIfNeeded(agentId)
|
||||
const attachments = planAttachment ? [planAttachment] : []
|
||||
|
||||
return {
|
||||
boundaryMarker: annotateBoundaryWithPreservedSegment(
|
||||
boundaryMarker,
|
||||
summaryMessages[summaryMessages.length - 1]!.uuid,
|
||||
messagesToKeep,
|
||||
),
|
||||
summaryMessages,
|
||||
attachments,
|
||||
hookResults,
|
||||
messagesToKeep,
|
||||
preCompactTokenCount,
|
||||
// SM-compact has no compact-API-call, so postCompactTokenCount (kept for
|
||||
// event continuity) and truePostCompactTokenCount converge to the same value.
|
||||
postCompactTokenCount: estimateMessageTokens(summaryMessages),
|
||||
truePostCompactTokenCount: estimateMessageTokens(summaryMessages),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to use session memory for compaction instead of traditional compaction.
|
||||
* Returns null if session memory compaction cannot be used.
|
||||
*
|
||||
* Handles two scenarios:
|
||||
* 1. Normal case: lastSummarizedMessageId is set, keep only messages after that ID
|
||||
* 2. Resumed session: lastSummarizedMessageId is not set but session memory has content,
|
||||
* keep all messages but use session memory as the summary
|
||||
*/
|
||||
export async function trySessionMemoryCompaction(
|
||||
messages: Message[],
|
||||
agentId?: AgentId,
|
||||
autoCompactThreshold?: number,
|
||||
): Promise<CompactionResult | null> {
|
||||
if (!shouldUseSessionMemoryCompaction()) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Initialize config from remote (only fetches once)
|
||||
await initSessionMemoryCompactConfig()
|
||||
|
||||
// Wait for any in-progress session memory extraction to complete (with timeout)
|
||||
await waitForSessionMemoryExtraction()
|
||||
|
||||
const lastSummarizedMessageId = getLastSummarizedMessageId()
|
||||
const sessionMemory = await getSessionMemoryContent()
|
||||
|
||||
// No session memory file exists at all
|
||||
if (!sessionMemory) {
|
||||
logEvent('tengu_sm_compact_no_session_memory', {})
|
||||
return null
|
||||
}
|
||||
|
||||
// Session memory exists but matches the template (no actual content extracted)
|
||||
// Fall back to legacy compact behavior
|
||||
if (await isSessionMemoryEmpty(sessionMemory)) {
|
||||
logEvent('tengu_sm_compact_empty_template', {})
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
let lastSummarizedIndex: number
|
||||
|
||||
if (lastSummarizedMessageId) {
|
||||
// Normal case: we know exactly which messages have been summarized
|
||||
lastSummarizedIndex = messages.findIndex(
|
||||
msg => msg.uuid === lastSummarizedMessageId,
|
||||
)
|
||||
|
||||
if (lastSummarizedIndex === -1) {
|
||||
// The summarized message ID doesn't exist in current messages
|
||||
// This can happen if messages were modified - fall back to legacy compact
|
||||
// since we can't determine the boundary between summarized and unsummarized messages
|
||||
logEvent('tengu_sm_compact_summarized_id_not_found', {})
|
||||
return null
|
||||
}
|
||||
} else {
|
||||
// Resumed session case: session memory has content but we don't know the boundary
|
||||
// Set lastSummarizedIndex to last message so startIndex becomes messages.length (no messages kept initially)
|
||||
lastSummarizedIndex = messages.length - 1
|
||||
logEvent('tengu_sm_compact_resumed_session', {})
|
||||
}
|
||||
|
||||
// Calculate the starting index for messages to keep
|
||||
// This starts from lastSummarizedIndex, expands to meet minimums,
|
||||
// and adjusts to not split tool_use/tool_result pairs
|
||||
const startIndex = calculateMessagesToKeepIndex(
|
||||
messages,
|
||||
lastSummarizedIndex,
|
||||
)
|
||||
// Filter out old compact boundary messages from messagesToKeep.
|
||||
// After REPL pruning, old boundaries re-yielded from messagesToKeep would
|
||||
// trigger an unwanted second prune (isCompactBoundaryMessage returns true),
|
||||
// discarding the new boundary and summary.
|
||||
const messagesToKeep = messages
|
||||
.slice(startIndex)
|
||||
.filter(m => !isCompactBoundaryMessage(m))
|
||||
|
||||
// Run session start hooks to restore CLAUDE.md and other context
|
||||
const hookResults = await processSessionStartHooks('compact', {
|
||||
model: getMainLoopModel(),
|
||||
})
|
||||
|
||||
// Get transcript path for the summary message
|
||||
const transcriptPath = getTranscriptPath()
|
||||
|
||||
const compactionResult = createCompactionResultFromSessionMemory(
|
||||
messages,
|
||||
sessionMemory,
|
||||
messagesToKeep,
|
||||
hookResults,
|
||||
transcriptPath,
|
||||
agentId,
|
||||
)
|
||||
|
||||
const postCompactMessages = buildPostCompactMessages(compactionResult)
|
||||
|
||||
const postCompactTokenCount = estimateMessageTokens(postCompactMessages)
|
||||
|
||||
// Only check threshold if one was provided (for autocompact)
|
||||
if (
|
||||
autoCompactThreshold !== undefined &&
|
||||
postCompactTokenCount >= autoCompactThreshold
|
||||
) {
|
||||
logEvent('tengu_sm_compact_threshold_exceeded', {
|
||||
postCompactTokenCount,
|
||||
autoCompactThreshold,
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
return {
|
||||
...compactionResult,
|
||||
postCompactTokenCount,
|
||||
truePostCompactTokenCount: postCompactTokenCount,
|
||||
}
|
||||
} catch (error) {
|
||||
// Use logEvent instead of logError since errors here are expected
|
||||
// (e.g., file not found, path issues) and shouldn't go to error logs
|
||||
logEvent('tengu_sm_compact_error', {})
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
logForDebugging(`Session memory compaction error: ${errorMessage(error)}`)
|
||||
}
|
||||
return null
|
||||
}
|
||||
}
|
||||
43
src/services/compact/timeBasedMCConfig.ts
Normal file
43
src/services/compact/timeBasedMCConfig.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
|
||||
/**
|
||||
* GrowthBook config for time-based microcompact.
|
||||
*
|
||||
* Triggers content-clearing microcompact when the gap since the last main-loop
|
||||
* assistant message exceeds a threshold — the server-side prompt cache has
|
||||
* almost certainly expired, so the full prefix will be rewritten anyway.
|
||||
* Clearing old tool results before the request shrinks what gets rewritten.
|
||||
*
|
||||
* Runs BEFORE the API call (in microcompactMessages, upstream of callModel)
|
||||
* so the shrunk prompt is what actually gets sent. Running after the first
|
||||
* miss would only help subsequent turns.
|
||||
*
|
||||
* Main thread only — subagents have short lifetimes where gap-based eviction
|
||||
* doesn't apply.
|
||||
*/
|
||||
export type TimeBasedMCConfig = {
|
||||
/** Master switch. When false, time-based microcompact is a no-op. */
|
||||
enabled: boolean
|
||||
/** Trigger when (now − last assistant timestamp) exceeds this many minutes.
|
||||
* 60 is the safe choice: the server's 1h cache TTL is guaranteed expired
|
||||
* for all users, so we never force a miss that wouldn't have happened. */
|
||||
gapThresholdMinutes: number
|
||||
/** Keep this many most-recent compactable tool results.
|
||||
* When set, takes priority over any default; older results are cleared. */
|
||||
keepRecent: number
|
||||
}
|
||||
|
||||
const TIME_BASED_MC_CONFIG_DEFAULTS: TimeBasedMCConfig = {
|
||||
enabled: false,
|
||||
gapThresholdMinutes: 60,
|
||||
keepRecent: 5,
|
||||
}
|
||||
|
||||
export function getTimeBasedMCConfig(): TimeBasedMCConfig {
|
||||
// Hoist the GB read so exposure fires on every eval path, not just when
|
||||
// the caller's other conditions (querySource, messages.length) pass.
|
||||
return getFeatureValue_CACHED_MAY_BE_STALE<TimeBasedMCConfig>(
|
||||
'tengu_slate_heron',
|
||||
TIME_BASED_MC_CONFIG_DEFAULTS,
|
||||
)
|
||||
}
|
||||
397
src/services/diagnosticTracking.ts
Normal file
397
src/services/diagnosticTracking.ts
Normal file
@@ -0,0 +1,397 @@
|
||||
import figures from 'figures'
|
||||
import { logError } from 'src/utils/log.js'
|
||||
import { callIdeRpc } from '../services/mcp/client.js'
|
||||
import type { MCPServerConnection } from '../services/mcp/types.js'
|
||||
import { ClaudeError } from '../utils/errors.js'
|
||||
import { normalizePathForComparison, pathsEqual } from '../utils/file.js'
|
||||
import { getConnectedIdeClient } from '../utils/ide.js'
|
||||
import { jsonParse } from '../utils/slowOperations.js'
|
||||
|
||||
class DiagnosticsTrackingError extends ClaudeError {}
|
||||
|
||||
const MAX_DIAGNOSTICS_SUMMARY_CHARS = 4000
|
||||
|
||||
export interface Diagnostic {
|
||||
message: string
|
||||
severity: 'Error' | 'Warning' | 'Info' | 'Hint'
|
||||
range: {
|
||||
start: { line: number; character: number }
|
||||
end: { line: number; character: number }
|
||||
}
|
||||
source?: string
|
||||
code?: string
|
||||
}
|
||||
|
||||
export interface DiagnosticFile {
|
||||
uri: string
|
||||
diagnostics: Diagnostic[]
|
||||
}
|
||||
|
||||
export class DiagnosticTrackingService {
|
||||
private static instance: DiagnosticTrackingService | undefined
|
||||
private baseline: Map<string, Diagnostic[]> = new Map()
|
||||
|
||||
private initialized = false
|
||||
private mcpClient: MCPServerConnection | undefined
|
||||
|
||||
// Track when files were last processed/fetched
|
||||
private lastProcessedTimestamps: Map<string, number> = new Map()
|
||||
|
||||
// Track which files have received right file diagnostics and if they've changed
|
||||
// Map<normalizedPath, lastClaudeFsRightDiagnostics>
|
||||
private rightFileDiagnosticsState: Map<string, Diagnostic[]> = new Map()
|
||||
|
||||
static getInstance(): DiagnosticTrackingService {
|
||||
if (!DiagnosticTrackingService.instance) {
|
||||
DiagnosticTrackingService.instance = new DiagnosticTrackingService()
|
||||
}
|
||||
return DiagnosticTrackingService.instance
|
||||
}
|
||||
|
||||
initialize(mcpClient: MCPServerConnection) {
|
||||
if (this.initialized) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Do not cache the connected mcpClient since it can change.
|
||||
this.mcpClient = mcpClient
|
||||
this.initialized = true
|
||||
}
|
||||
|
||||
async shutdown(): Promise<void> {
|
||||
this.initialized = false
|
||||
this.baseline.clear()
|
||||
this.rightFileDiagnosticsState.clear()
|
||||
this.lastProcessedTimestamps.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset tracking state while keeping the service initialized.
|
||||
* This clears all tracked files and diagnostics.
|
||||
*/
|
||||
reset() {
|
||||
this.baseline.clear()
|
||||
this.rightFileDiagnosticsState.clear()
|
||||
this.lastProcessedTimestamps.clear()
|
||||
}
|
||||
|
||||
private normalizeFileUri(fileUri: string): string {
|
||||
// Remove our protocol prefixes
|
||||
const protocolPrefixes = [
|
||||
'file://',
|
||||
'_claude_fs_right:',
|
||||
'_claude_fs_left:',
|
||||
]
|
||||
|
||||
let normalized = fileUri
|
||||
for (const prefix of protocolPrefixes) {
|
||||
if (fileUri.startsWith(prefix)) {
|
||||
normalized = fileUri.slice(prefix.length)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Use shared utility for platform-aware path normalization
|
||||
// (handles Windows case-insensitivity and path separators)
|
||||
return normalizePathForComparison(normalized)
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure a file is opened in the IDE before processing.
|
||||
* This is important for language services like diagnostics to work properly.
|
||||
*/
|
||||
async ensureFileOpened(fileUri: string): Promise<void> {
|
||||
if (
|
||||
!this.initialized ||
|
||||
!this.mcpClient ||
|
||||
this.mcpClient.type !== 'connected'
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
// Call the openFile tool to ensure the file is loaded
|
||||
await callIdeRpc(
|
||||
'openFile',
|
||||
{
|
||||
filePath: fileUri,
|
||||
preview: false,
|
||||
startText: '',
|
||||
endText: '',
|
||||
selectToEndOfLine: false,
|
||||
makeFrontmost: false,
|
||||
},
|
||||
this.mcpClient,
|
||||
)
|
||||
} catch (error) {
|
||||
logError(error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture baseline diagnostics for a specific file before editing.
|
||||
* This is called before editing a file to ensure we have a baseline to compare against.
|
||||
*/
|
||||
async beforeFileEdited(filePath: string): Promise<void> {
|
||||
if (
|
||||
!this.initialized ||
|
||||
!this.mcpClient ||
|
||||
this.mcpClient.type !== 'connected'
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
const timestamp = Date.now()
|
||||
|
||||
try {
|
||||
const result = await callIdeRpc(
|
||||
'getDiagnostics',
|
||||
{ uri: `file://${filePath}` },
|
||||
this.mcpClient,
|
||||
)
|
||||
const diagnosticFile = this.parseDiagnosticResult(result)[0]
|
||||
if (diagnosticFile) {
|
||||
// Compare normalized paths (handles protocol prefixes and Windows case-insensitivity)
|
||||
if (
|
||||
!pathsEqual(
|
||||
this.normalizeFileUri(filePath),
|
||||
this.normalizeFileUri(diagnosticFile.uri),
|
||||
)
|
||||
) {
|
||||
logError(
|
||||
new DiagnosticsTrackingError(
|
||||
`Diagnostics file path mismatch: expected ${filePath}, got ${diagnosticFile.uri})`,
|
||||
),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Store with normalized path key for consistent lookups on Windows
|
||||
const normalizedPath = this.normalizeFileUri(filePath)
|
||||
this.baseline.set(normalizedPath, diagnosticFile.diagnostics)
|
||||
this.lastProcessedTimestamps.set(normalizedPath, timestamp)
|
||||
} else {
|
||||
// No diagnostic file returned, store an empty baseline
|
||||
const normalizedPath = this.normalizeFileUri(filePath)
|
||||
this.baseline.set(normalizedPath, [])
|
||||
this.lastProcessedTimestamps.set(normalizedPath, timestamp)
|
||||
}
|
||||
} catch (_error) {
|
||||
// Fail silently if IDE doesn't support diagnostics
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get new diagnostics from file://, _claude_fs_right, and _claude_fs_ URIs that aren't in the baseline.
|
||||
* Only processes diagnostics for files that have been edited.
|
||||
*/
|
||||
async getNewDiagnostics(): Promise<DiagnosticFile[]> {
|
||||
if (
|
||||
!this.initialized ||
|
||||
!this.mcpClient ||
|
||||
this.mcpClient.type !== 'connected'
|
||||
) {
|
||||
return []
|
||||
}
|
||||
|
||||
// Check if we have any files with diagnostic changes
|
||||
let allDiagnosticFiles: DiagnosticFile[] = []
|
||||
try {
|
||||
const result = await callIdeRpc(
|
||||
'getDiagnostics',
|
||||
{}, // Empty params fetches all diagnostics
|
||||
this.mcpClient,
|
||||
)
|
||||
allDiagnosticFiles = this.parseDiagnosticResult(result)
|
||||
} catch (_error) {
|
||||
// If fetching all diagnostics fails, return empty
|
||||
return []
|
||||
}
|
||||
const diagnosticsForFileUrisWithBaselines = allDiagnosticFiles
|
||||
.filter(file => this.baseline.has(this.normalizeFileUri(file.uri)))
|
||||
.filter(file => file.uri.startsWith('file://'))
|
||||
|
||||
const diagnosticsForClaudeFsRightUrisWithBaselinesMap = new Map<
|
||||
string,
|
||||
DiagnosticFile
|
||||
>()
|
||||
allDiagnosticFiles
|
||||
.filter(file => this.baseline.has(this.normalizeFileUri(file.uri)))
|
||||
.filter(file => file.uri.startsWith('_claude_fs_right:'))
|
||||
.forEach(file => {
|
||||
diagnosticsForClaudeFsRightUrisWithBaselinesMap.set(
|
||||
this.normalizeFileUri(file.uri),
|
||||
file,
|
||||
)
|
||||
})
|
||||
|
||||
const newDiagnosticFiles: DiagnosticFile[] = []
|
||||
|
||||
// Process file:// protocol diagnostics
|
||||
for (const file of diagnosticsForFileUrisWithBaselines) {
|
||||
const normalizedPath = this.normalizeFileUri(file.uri)
|
||||
const baselineDiagnostics = this.baseline.get(normalizedPath) || []
|
||||
|
||||
// Get the _claude_fs_right file if it exists
|
||||
const claudeFsRightFile =
|
||||
diagnosticsForClaudeFsRightUrisWithBaselinesMap.get(normalizedPath)
|
||||
|
||||
// Determine which file to use based on the state of right file diagnostics
|
||||
let fileToUse = file
|
||||
|
||||
if (claudeFsRightFile) {
|
||||
const previousRightDiagnostics =
|
||||
this.rightFileDiagnosticsState.get(normalizedPath)
|
||||
|
||||
// Use _claude_fs_right if:
|
||||
// 1. We've never gotten right file diagnostics for this file (previousRightDiagnostics === undefined)
|
||||
// 2. OR the right file diagnostics have just changed
|
||||
if (
|
||||
!previousRightDiagnostics ||
|
||||
!this.areDiagnosticArraysEqual(
|
||||
previousRightDiagnostics,
|
||||
claudeFsRightFile.diagnostics,
|
||||
)
|
||||
) {
|
||||
fileToUse = claudeFsRightFile
|
||||
}
|
||||
|
||||
// Update our tracking of right file diagnostics
|
||||
this.rightFileDiagnosticsState.set(
|
||||
normalizedPath,
|
||||
claudeFsRightFile.diagnostics,
|
||||
)
|
||||
}
|
||||
|
||||
// Find new diagnostics that aren't in the baseline
|
||||
const newDiagnostics = fileToUse.diagnostics.filter(
|
||||
d => !baselineDiagnostics.some(b => this.areDiagnosticsEqual(d, b)),
|
||||
)
|
||||
|
||||
if (newDiagnostics.length > 0) {
|
||||
newDiagnosticFiles.push({
|
||||
uri: file.uri,
|
||||
diagnostics: newDiagnostics,
|
||||
})
|
||||
}
|
||||
|
||||
// Update baseline with current diagnostics
|
||||
this.baseline.set(normalizedPath, fileToUse.diagnostics)
|
||||
}
|
||||
|
||||
return newDiagnosticFiles
|
||||
}
|
||||
|
||||
private parseDiagnosticResult(result: unknown): DiagnosticFile[] {
|
||||
if (Array.isArray(result)) {
|
||||
const textBlock = result.find(block => block.type === 'text')
|
||||
if (textBlock && 'text' in textBlock) {
|
||||
const parsed = jsonParse(textBlock.text)
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return []
|
||||
}
|
||||
|
||||
private areDiagnosticsEqual(a: Diagnostic, b: Diagnostic): boolean {
|
||||
return (
|
||||
a.message === b.message &&
|
||||
a.severity === b.severity &&
|
||||
a.source === b.source &&
|
||||
a.code === b.code &&
|
||||
a.range.start.line === b.range.start.line &&
|
||||
a.range.start.character === b.range.start.character &&
|
||||
a.range.end.line === b.range.end.line &&
|
||||
a.range.end.character === b.range.end.character
|
||||
)
|
||||
}
|
||||
|
||||
private areDiagnosticArraysEqual(a: Diagnostic[], b: Diagnostic[]): boolean {
|
||||
if (a.length !== b.length) return false
|
||||
|
||||
// Check if every diagnostic in 'a' exists in 'b'
|
||||
return (
|
||||
a.every(diagA =>
|
||||
b.some(diagB => this.areDiagnosticsEqual(diagA, diagB)),
|
||||
) &&
|
||||
b.every(diagB => a.some(diagA => this.areDiagnosticsEqual(diagA, diagB)))
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle the start of a new query. This method:
|
||||
* - Initializes the diagnostic tracker if not already initialized
|
||||
* - Resets the tracker if already initialized (for new query loops)
|
||||
* - Automatically finds the IDE client from the provided clients list
|
||||
*
|
||||
* @param clients Array of MCP clients that may include an IDE client
|
||||
* @param shouldQuery Whether a query is actually being made (not just a command)
|
||||
*/
|
||||
async handleQueryStart(clients: MCPServerConnection[]): Promise<void> {
|
||||
// Only proceed if we should query and have clients
|
||||
if (!this.initialized) {
|
||||
// Find the connected IDE client
|
||||
const connectedIdeClient = getConnectedIdeClient(clients)
|
||||
|
||||
if (connectedIdeClient) {
|
||||
this.initialize(connectedIdeClient)
|
||||
}
|
||||
} else {
|
||||
// Reset diagnostic tracking for new query loops
|
||||
this.reset()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Format diagnostics into a human-readable summary string.
|
||||
* This is useful for displaying diagnostics in messages or logs.
|
||||
*
|
||||
* @param files Array of diagnostic files to format
|
||||
* @returns Formatted string representation of the diagnostics
|
||||
*/
|
||||
static formatDiagnosticsSummary(files: DiagnosticFile[]): string {
|
||||
const truncationMarker = '…[truncated]'
|
||||
const result = files
|
||||
.map(file => {
|
||||
const filename = file.uri.split('/').pop() || file.uri
|
||||
const diagnostics = file.diagnostics
|
||||
.map(d => {
|
||||
const severitySymbol = DiagnosticTrackingService.getSeveritySymbol(
|
||||
d.severity,
|
||||
)
|
||||
|
||||
return ` ${severitySymbol} [Line ${d.range.start.line + 1}:${d.range.start.character + 1}] ${d.message}${d.code ? ` [${d.code}]` : ''}${d.source ? ` (${d.source})` : ''}`
|
||||
})
|
||||
.join('\n')
|
||||
|
||||
return `${filename}:\n${diagnostics}`
|
||||
})
|
||||
.join('\n\n')
|
||||
|
||||
if (result.length > MAX_DIAGNOSTICS_SUMMARY_CHARS) {
|
||||
return (
|
||||
result.slice(
|
||||
0,
|
||||
MAX_DIAGNOSTICS_SUMMARY_CHARS - truncationMarker.length,
|
||||
) + truncationMarker
|
||||
)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the severity symbol for a diagnostic
|
||||
*/
|
||||
static getSeveritySymbol(severity: Diagnostic['severity']): string {
|
||||
return (
|
||||
{
|
||||
Error: figures.cross,
|
||||
Warning: figures.warning,
|
||||
Info: figures.info,
|
||||
Hint: figures.star,
|
||||
}[severity] || figures.bullet
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
export const diagnosticTracker = DiagnosticTrackingService.getInstance()
|
||||
615
src/services/extractMemories/extractMemories.ts
Normal file
615
src/services/extractMemories/extractMemories.ts
Normal file
@@ -0,0 +1,615 @@
|
||||
/**
|
||||
* Extracts durable memories from the current session transcript
|
||||
* and writes them to the auto-memory directory (~/.claude/projects/<path>/memory/).
|
||||
*
|
||||
* It runs once at the end of each complete query loop (when the model produces
|
||||
* a final response with no tool calls) via handleStopHooks in stopHooks.ts.
|
||||
*
|
||||
* Uses the forked agent pattern (runForkedAgent) — a perfect fork of the main
|
||||
* conversation that shares the parent's prompt cache.
|
||||
*
|
||||
* State is closure-scoped inside initExtractMemories() rather than module-level,
|
||||
* following the same pattern as confidenceRating.ts. Tests call
|
||||
* initExtractMemories() in beforeEach to get a fresh closure.
|
||||
*/
|
||||
|
||||
import { feature } from 'bun:bundle'
|
||||
import { basename } from 'path'
|
||||
import { getIsRemoteMode } from '../../bootstrap/state.js'
|
||||
import type { CanUseToolFn } from '../../hooks/useCanUseTool.js'
|
||||
import { ENTRYPOINT_NAME } from '../../memdir/memdir.js'
|
||||
import {
|
||||
formatMemoryManifest,
|
||||
scanMemoryFiles,
|
||||
} from '../../memdir/memoryScan.js'
|
||||
import {
|
||||
getAutoMemPath,
|
||||
isAutoMemoryEnabled,
|
||||
isAutoMemPath,
|
||||
} from '../../memdir/paths.js'
|
||||
import type { Tool } from '../../Tool.js'
|
||||
import { BASH_TOOL_NAME } from '../../tools/BashTool/toolName.js'
|
||||
import { FILE_EDIT_TOOL_NAME } from '../../tools/FileEditTool/constants.js'
|
||||
import { FILE_READ_TOOL_NAME } from '../../tools/FileReadTool/prompt.js'
|
||||
import { FILE_WRITE_TOOL_NAME } from '../../tools/FileWriteTool/prompt.js'
|
||||
import { GLOB_TOOL_NAME } from '../../tools/GlobTool/prompt.js'
|
||||
import { GREP_TOOL_NAME } from '../../tools/GrepTool/prompt.js'
|
||||
import { REPL_TOOL_NAME } from '../../tools/REPLTool/constants.js'
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Message,
|
||||
SystemLocalCommandMessage,
|
||||
SystemMessage,
|
||||
} from '../../types/message.js'
|
||||
import { createAbortController } from '../../utils/abortController.js'
|
||||
import { count, uniq } from '../../utils/array.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import {
|
||||
createCacheSafeParams,
|
||||
runForkedAgent,
|
||||
} from '../../utils/forkedAgent.js'
|
||||
import type { REPLHookContext } from '../../utils/hooks/postSamplingHooks.js'
|
||||
import {
|
||||
createMemorySavedMessage,
|
||||
createUserMessage,
|
||||
} from '../../utils/messages.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
import { sanitizeToolNameForAnalytics } from '../analytics/metadata.js'
|
||||
import {
|
||||
buildExtractAutoOnlyPrompt,
|
||||
buildExtractCombinedPrompt,
|
||||
} from './prompts.js'
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-require-imports */
|
||||
const teamMemPaths = feature('TEAMMEM')
|
||||
? (require('../../memdir/teamMemPaths.js') as typeof import('../../memdir/teamMemPaths.js'))
|
||||
: null
|
||||
/* eslint-enable @typescript-eslint/no-require-imports */
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Returns true if a message is visible to the model (sent in API calls).
|
||||
* Excludes progress, system, and attachment messages.
|
||||
*/
|
||||
function isModelVisibleMessage(message: Message): boolean {
|
||||
return message.type === 'user' || message.type === 'assistant'
|
||||
}
|
||||
|
||||
function countModelVisibleMessagesSince(
|
||||
messages: Message[],
|
||||
sinceUuid: string | undefined,
|
||||
): number {
|
||||
if (sinceUuid === null || sinceUuid === undefined) {
|
||||
return count(messages, isModelVisibleMessage)
|
||||
}
|
||||
|
||||
let foundStart = false
|
||||
let n = 0
|
||||
for (const message of messages) {
|
||||
if (!foundStart) {
|
||||
if (message.uuid === sinceUuid) {
|
||||
foundStart = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if (isModelVisibleMessage(message)) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
// If sinceUuid was not found (e.g., removed by context compaction),
|
||||
// fall back to counting all model-visible messages rather than returning 0
|
||||
// which would permanently disable extraction for the rest of the session.
|
||||
if (!foundStart) {
|
||||
return count(messages, isModelVisibleMessage)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if any assistant message after the cursor UUID contains a
|
||||
* Write/Edit tool_use block targeting an auto-memory path.
|
||||
*
|
||||
* The main agent's prompt has full save instructions — when it writes
|
||||
* memories, the forked extraction is redundant. runExtraction skips the
|
||||
* agent and advances the cursor past this range, making the main agent
|
||||
* and the background agent mutually exclusive per turn.
|
||||
*/
|
||||
function hasMemoryWritesSince(
|
||||
messages: Message[],
|
||||
sinceUuid: string | undefined,
|
||||
): boolean {
|
||||
let foundStart = sinceUuid === undefined
|
||||
for (const message of messages) {
|
||||
if (!foundStart) {
|
||||
if (message.uuid === sinceUuid) {
|
||||
foundStart = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if (message.type !== 'assistant') {
|
||||
continue
|
||||
}
|
||||
const content = (message as AssistantMessage).message.content
|
||||
if (!Array.isArray(content)) {
|
||||
continue
|
||||
}
|
||||
for (const block of content) {
|
||||
const filePath = getWrittenFilePath(block)
|
||||
if (filePath !== undefined && isAutoMemPath(filePath)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tool Permissions
|
||||
// ============================================================================
|
||||
|
||||
function denyAutoMemTool(tool: Tool, reason: string) {
|
||||
logForDebugging(`[autoMem] denied ${tool.name}: ${reason}`)
|
||||
logEvent('tengu_auto_mem_tool_denied', {
|
||||
tool_name: sanitizeToolNameForAnalytics(tool.name),
|
||||
})
|
||||
return {
|
||||
behavior: 'deny' as const,
|
||||
message: reason,
|
||||
decisionReason: { type: 'other' as const, reason },
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a canUseTool function that allows Read/Grep/Glob (unrestricted),
|
||||
* read-only Bash commands, and Edit/Write only for paths within the
|
||||
* auto-memory directory. Shared by extractMemories and autoDream.
|
||||
*/
|
||||
export function createAutoMemCanUseTool(memoryDir: string): CanUseToolFn {
|
||||
return async (tool: Tool, input: Record<string, unknown>) => {
|
||||
// Allow REPL — when REPL mode is enabled (ant-default), primitive tools
|
||||
// are hidden from the tool list so the forked agent calls REPL instead.
|
||||
// REPL's VM context re-invokes this canUseTool for each inner primitive
|
||||
// (toolWrappers.ts createToolWrapper), so the Read/Bash/Edit/Write checks
|
||||
// below still gate the actual file and shell operations. Giving the fork a
|
||||
// different tool list would break prompt cache sharing (tools are part of
|
||||
// the cache key — see CacheSafeParams in forkedAgent.ts).
|
||||
if (tool.name === REPL_TOOL_NAME) {
|
||||
return { behavior: 'allow' as const, updatedInput: input }
|
||||
}
|
||||
|
||||
// Allow Read/Grep/Glob unrestricted — all inherently read-only
|
||||
if (
|
||||
tool.name === FILE_READ_TOOL_NAME ||
|
||||
tool.name === GREP_TOOL_NAME ||
|
||||
tool.name === GLOB_TOOL_NAME
|
||||
) {
|
||||
return { behavior: 'allow' as const, updatedInput: input }
|
||||
}
|
||||
|
||||
// Allow Bash only for commands that pass BashTool.isReadOnly.
|
||||
// `tool` IS BashTool here — no static import needed.
|
||||
if (tool.name === BASH_TOOL_NAME) {
|
||||
const parsed = tool.inputSchema.safeParse(input)
|
||||
if (parsed.success && tool.isReadOnly(parsed.data)) {
|
||||
return { behavior: 'allow' as const, updatedInput: input }
|
||||
}
|
||||
return denyAutoMemTool(
|
||||
tool,
|
||||
'Only read-only shell commands are permitted in this context (ls, find, grep, cat, stat, wc, head, tail, and similar)',
|
||||
)
|
||||
}
|
||||
|
||||
if (
|
||||
(tool.name === FILE_EDIT_TOOL_NAME ||
|
||||
tool.name === FILE_WRITE_TOOL_NAME) &&
|
||||
'file_path' in input
|
||||
) {
|
||||
const filePath = input.file_path
|
||||
if (typeof filePath === 'string' && isAutoMemPath(filePath)) {
|
||||
return { behavior: 'allow' as const, updatedInput: input }
|
||||
}
|
||||
}
|
||||
|
||||
return denyAutoMemTool(
|
||||
tool,
|
||||
`only ${FILE_READ_TOOL_NAME}, ${GREP_TOOL_NAME}, ${GLOB_TOOL_NAME}, read-only ${BASH_TOOL_NAME}, and ${FILE_EDIT_TOOL_NAME}/${FILE_WRITE_TOOL_NAME} within ${memoryDir} are allowed`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Extract file paths from agent output
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract file_path from a tool_use block's input, if present.
|
||||
* Returns undefined when the block is not an Edit/Write tool use or has no file_path.
|
||||
*/
|
||||
function getWrittenFilePath(block: {
|
||||
type: string
|
||||
name?: string
|
||||
input?: unknown
|
||||
}): string | undefined {
|
||||
if (
|
||||
block.type !== 'tool_use' ||
|
||||
(block.name !== FILE_EDIT_TOOL_NAME && block.name !== FILE_WRITE_TOOL_NAME)
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
const input = block.input
|
||||
if (typeof input === 'object' && input !== null && 'file_path' in input) {
|
||||
const fp = (input as { file_path: unknown }).file_path
|
||||
return typeof fp === 'string' ? fp : undefined
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
function extractWrittenPaths(agentMessages: Message[]): string[] {
|
||||
const paths: string[] = []
|
||||
for (const message of agentMessages) {
|
||||
if (message.type !== 'assistant') {
|
||||
continue
|
||||
}
|
||||
const content = (message as AssistantMessage).message.content
|
||||
if (!Array.isArray(content)) {
|
||||
continue
|
||||
}
|
||||
for (const block of content) {
|
||||
const filePath = getWrittenFilePath(block)
|
||||
if (filePath !== undefined) {
|
||||
paths.push(filePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
return uniq(paths)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Initialization & Closure-scoped State
|
||||
// ============================================================================
|
||||
|
||||
type AppendSystemMessageFn = (
|
||||
msg: Exclude<SystemMessage, SystemLocalCommandMessage>,
|
||||
) => void
|
||||
|
||||
/** The active extractor function, set by initExtractMemories(). */
|
||||
let extractor:
|
||||
| ((
|
||||
context: REPLHookContext,
|
||||
appendSystemMessage?: AppendSystemMessageFn,
|
||||
) => Promise<void>)
|
||||
| null = null
|
||||
|
||||
/** The active drain function, set by initExtractMemories(). No-op until init. */
|
||||
let drainer: (timeoutMs?: number) => Promise<void> = async () => {}
|
||||
|
||||
/**
|
||||
* Initialize the memory extraction system.
|
||||
* Creates a fresh closure that captures all mutable state (cursor position,
|
||||
* overlap guard, pending context). Call once at startup alongside
|
||||
* initConfidenceRating/initPromptCoaching, or per-test in beforeEach.
|
||||
*/
|
||||
export function initExtractMemories(): void {
|
||||
// --- Closure-scoped mutable state ---
|
||||
|
||||
/** Every promise handed out by the extractor that hasn't settled yet.
|
||||
* Coalesced calls that stash-and-return add fast-resolving promises
|
||||
* (harmless); the call that starts real work adds a promise covering the
|
||||
* full trailing-run chain via runExtraction's recursive finally. */
|
||||
const inFlightExtractions = new Set<Promise<void>>()
|
||||
|
||||
/** UUID of the last message processed — cursor so each run only
|
||||
* considers messages added since the previous extraction. */
|
||||
let lastMemoryMessageUuid: string | undefined
|
||||
|
||||
/** One-shot flag: once we log that the gate is disabled, don't repeat. */
|
||||
let hasLoggedGateFailure = false
|
||||
|
||||
/** True while runExtraction is executing — prevents overlapping runs. */
|
||||
let inProgress = false
|
||||
|
||||
/** Counts eligible turns since the last extraction run. Resets to 0 after each run. */
|
||||
let turnsSinceLastExtraction = 0
|
||||
|
||||
/** When a call arrives during an in-progress run, we stash the context here
|
||||
* and run one trailing extraction after the current one finishes. */
|
||||
let pendingContext:
|
||||
| {
|
||||
context: REPLHookContext
|
||||
appendSystemMessage?: AppendSystemMessageFn
|
||||
}
|
||||
| undefined
|
||||
|
||||
// --- Inner extraction logic ---
|
||||
|
||||
async function runExtraction({
|
||||
context,
|
||||
appendSystemMessage,
|
||||
isTrailingRun,
|
||||
}: {
|
||||
context: REPLHookContext
|
||||
appendSystemMessage?: AppendSystemMessageFn
|
||||
isTrailingRun?: boolean
|
||||
}): Promise<void> {
|
||||
const { messages } = context
|
||||
const memoryDir = getAutoMemPath()
|
||||
const newMessageCount = countModelVisibleMessagesSince(
|
||||
messages,
|
||||
lastMemoryMessageUuid,
|
||||
)
|
||||
|
||||
// Mutual exclusion: when the main agent wrote memories, skip the
|
||||
// forked agent and advance the cursor past this range so the next
|
||||
// extraction only considers messages after the main agent's write.
|
||||
if (hasMemoryWritesSince(messages, lastMemoryMessageUuid)) {
|
||||
logForDebugging(
|
||||
'[extractMemories] skipping — conversation already wrote to memory files',
|
||||
)
|
||||
const lastMessage = messages.at(-1)
|
||||
if (lastMessage?.uuid) {
|
||||
lastMemoryMessageUuid = lastMessage.uuid
|
||||
}
|
||||
logEvent('tengu_extract_memories_skipped_direct_write', {
|
||||
message_count: newMessageCount,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const teamMemoryEnabled = feature('TEAMMEM')
|
||||
? teamMemPaths!.isTeamMemoryEnabled()
|
||||
: false
|
||||
|
||||
const skipIndex = getFeatureValue_CACHED_MAY_BE_STALE(
|
||||
'tengu_moth_copse',
|
||||
false,
|
||||
)
|
||||
|
||||
const canUseTool = createAutoMemCanUseTool(memoryDir)
|
||||
const cacheSafeParams = createCacheSafeParams(context)
|
||||
|
||||
// Only run extraction every N eligible turns (tengu_bramble_lintel, default 1).
|
||||
// Trailing extractions (from stashed contexts) skip this check since they
|
||||
// process already-committed work that should not be throttled.
|
||||
if (!isTrailingRun) {
|
||||
turnsSinceLastExtraction++
|
||||
if (
|
||||
turnsSinceLastExtraction <
|
||||
(getFeatureValue_CACHED_MAY_BE_STALE('tengu_bramble_lintel', null) ?? 1)
|
||||
) {
|
||||
return
|
||||
}
|
||||
}
|
||||
turnsSinceLastExtraction = 0
|
||||
|
||||
inProgress = true
|
||||
const startTime = Date.now()
|
||||
try {
|
||||
logForDebugging(
|
||||
`[extractMemories] starting — ${newMessageCount} new messages, memoryDir=${memoryDir}`,
|
||||
)
|
||||
|
||||
// Pre-inject the memory directory manifest so the agent doesn't spend
|
||||
// a turn on `ls`. Reuses findRelevantMemories' frontmatter scan.
|
||||
// Placed after the throttle gate so skipped turns don't pay the scan cost.
|
||||
const existingMemories = formatMemoryManifest(
|
||||
await scanMemoryFiles(memoryDir, createAbortController().signal),
|
||||
)
|
||||
|
||||
const userPrompt =
|
||||
feature('TEAMMEM') && teamMemoryEnabled
|
||||
? buildExtractCombinedPrompt(
|
||||
newMessageCount,
|
||||
existingMemories,
|
||||
skipIndex,
|
||||
)
|
||||
: buildExtractAutoOnlyPrompt(
|
||||
newMessageCount,
|
||||
existingMemories,
|
||||
skipIndex,
|
||||
)
|
||||
|
||||
const result = await runForkedAgent({
|
||||
promptMessages: [createUserMessage({ content: userPrompt })],
|
||||
cacheSafeParams,
|
||||
canUseTool,
|
||||
querySource: 'extract_memories',
|
||||
forkLabel: 'extract_memories',
|
||||
// The extractMemories subagent does not need to record to transcript.
|
||||
// Doing so can create race conditions with the main thread.
|
||||
skipTranscript: true,
|
||||
// Well-behaved extractions complete in 2-4 turns (read → write).
|
||||
// A hard cap prevents verification rabbit-holes from burning turns.
|
||||
maxTurns: 5,
|
||||
})
|
||||
|
||||
// Advance the cursor only after a successful run. If the agent errors
|
||||
// out (caught below), the cursor stays put so those messages are
|
||||
// reconsidered on the next extraction.
|
||||
const lastMessage = messages.at(-1)
|
||||
if (lastMessage?.uuid) {
|
||||
lastMemoryMessageUuid = lastMessage.uuid
|
||||
}
|
||||
|
||||
const writtenPaths = extractWrittenPaths(result.messages)
|
||||
const turnCount = count(result.messages, m => m.type === 'assistant')
|
||||
|
||||
const totalInput =
|
||||
result.totalUsage.input_tokens +
|
||||
result.totalUsage.cache_creation_input_tokens +
|
||||
result.totalUsage.cache_read_input_tokens
|
||||
const hitPct =
|
||||
totalInput > 0
|
||||
? (
|
||||
(result.totalUsage.cache_read_input_tokens / totalInput) *
|
||||
100
|
||||
).toFixed(1)
|
||||
: '0.0'
|
||||
logForDebugging(
|
||||
`[extractMemories] finished — ${writtenPaths.length} files written, cache: read=${result.totalUsage.cache_read_input_tokens} create=${result.totalUsage.cache_creation_input_tokens} input=${result.totalUsage.input_tokens} (${hitPct}% hit)`,
|
||||
)
|
||||
|
||||
if (writtenPaths.length > 0) {
|
||||
logForDebugging(
|
||||
`[extractMemories] memories saved: ${writtenPaths.join(', ')}`,
|
||||
)
|
||||
} else {
|
||||
logForDebugging('[extractMemories] no memories saved this run')
|
||||
}
|
||||
|
||||
// Index file updates are mechanical — the agent touches MEMORY.md to add
|
||||
// a topic link, but the user-visible "memory" is the topic file itself.
|
||||
const memoryPaths = writtenPaths.filter(
|
||||
p => basename(p) !== ENTRYPOINT_NAME,
|
||||
)
|
||||
const teamCount = feature('TEAMMEM')
|
||||
? count(memoryPaths, teamMemPaths!.isTeamMemPath)
|
||||
: 0
|
||||
|
||||
// Log extraction event with usage from the forked agent
|
||||
logEvent('tengu_extract_memories_extraction', {
|
||||
input_tokens: result.totalUsage.input_tokens,
|
||||
output_tokens: result.totalUsage.output_tokens,
|
||||
cache_read_input_tokens: result.totalUsage.cache_read_input_tokens,
|
||||
cache_creation_input_tokens:
|
||||
result.totalUsage.cache_creation_input_tokens,
|
||||
message_count: newMessageCount,
|
||||
turn_count: turnCount,
|
||||
files_written: writtenPaths.length,
|
||||
memories_saved: memoryPaths.length,
|
||||
team_memories_saved: teamCount,
|
||||
duration_ms: Date.now() - startTime,
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
`[extractMemories] writtenPaths=${writtenPaths.length} memoryPaths=${memoryPaths.length} appendSystemMessage defined=${appendSystemMessage != null}`,
|
||||
)
|
||||
if (memoryPaths.length > 0) {
|
||||
const msg = createMemorySavedMessage(memoryPaths)
|
||||
if (feature('TEAMMEM')) {
|
||||
msg.teamCount = teamCount
|
||||
}
|
||||
appendSystemMessage?.(msg)
|
||||
}
|
||||
} catch (error) {
|
||||
// Extraction is best-effort — log but don't notify on error
|
||||
logForDebugging(`[extractMemories] error: ${error}`)
|
||||
logEvent('tengu_extract_memories_error', {
|
||||
duration_ms: Date.now() - startTime,
|
||||
})
|
||||
} finally {
|
||||
inProgress = false
|
||||
|
||||
// If a call arrived while we were running, run a trailing extraction
|
||||
// with the latest stashed context. The trailing run will compute its
|
||||
// newMessageCount relative to the cursor we just advanced — so it only
|
||||
// picks up messages added between the two calls, not the full history.
|
||||
const trailing = pendingContext
|
||||
pendingContext = undefined
|
||||
if (trailing) {
|
||||
logForDebugging(
|
||||
'[extractMemories] running trailing extraction for stashed context',
|
||||
)
|
||||
await runExtraction({
|
||||
context: trailing.context,
|
||||
appendSystemMessage: trailing.appendSystemMessage,
|
||||
isTrailingRun: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Public entry point (captured by extractor) ---
|
||||
|
||||
async function executeExtractMemoriesImpl(
|
||||
context: REPLHookContext,
|
||||
appendSystemMessage?: AppendSystemMessageFn,
|
||||
): Promise<void> {
|
||||
// Only run for the main agent, not subagents
|
||||
if (context.toolUseContext.agentId) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!getFeatureValue_CACHED_MAY_BE_STALE('tengu_passport_quail', false)) {
|
||||
if (process.env.USER_TYPE === 'ant' && !hasLoggedGateFailure) {
|
||||
hasLoggedGateFailure = true
|
||||
logEvent('tengu_extract_memories_gate_disabled', {})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check auto-memory is enabled
|
||||
if (!isAutoMemoryEnabled()) {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip in remote mode
|
||||
if (getIsRemoteMode()) {
|
||||
return
|
||||
}
|
||||
|
||||
// If an extraction is already in progress, stash this context for a
|
||||
// trailing run (overwrites any previously stashed context — only the
|
||||
// latest matters since it has the most messages).
|
||||
if (inProgress) {
|
||||
logForDebugging(
|
||||
'[extractMemories] extraction in progress — stashing for trailing run',
|
||||
)
|
||||
logEvent('tengu_extract_memories_coalesced', {})
|
||||
pendingContext = { context, appendSystemMessage }
|
||||
return
|
||||
}
|
||||
|
||||
await runExtraction({ context, appendSystemMessage })
|
||||
}
|
||||
|
||||
extractor = async (context, appendSystemMessage) => {
|
||||
const p = executeExtractMemoriesImpl(context, appendSystemMessage)
|
||||
inFlightExtractions.add(p)
|
||||
try {
|
||||
await p
|
||||
} finally {
|
||||
inFlightExtractions.delete(p)
|
||||
}
|
||||
}
|
||||
|
||||
drainer = async (timeoutMs = 60_000) => {
|
||||
if (inFlightExtractions.size === 0) return
|
||||
await Promise.race([
|
||||
Promise.all(inFlightExtractions).catch(() => {}),
|
||||
// eslint-disable-next-line no-restricted-syntax -- sleep() has no .unref(); timer must not block exit
|
||||
new Promise<void>(r => setTimeout(r, timeoutMs).unref()),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Public API
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Run memory extraction at the end of a query loop.
|
||||
* Called fire-and-forget from handleStopHooks, alongside prompt suggestion/coaching.
|
||||
* No-ops until initExtractMemories() has been called.
|
||||
*/
|
||||
export async function executeExtractMemories(
|
||||
context: REPLHookContext,
|
||||
appendSystemMessage?: AppendSystemMessageFn,
|
||||
): Promise<void> {
|
||||
await extractor?.(context, appendSystemMessage)
|
||||
}
|
||||
|
||||
/**
|
||||
* Awaits all in-flight extractions (including trailing stashed runs) with a
|
||||
* soft timeout. Called by print.ts after the response is flushed but before
|
||||
* gracefulShutdownSync, so the forked agent completes before the 5s shutdown
|
||||
* failsafe kills it. No-op until initExtractMemories() has been called.
|
||||
*/
|
||||
export async function drainPendingExtraction(
|
||||
timeoutMs?: number,
|
||||
): Promise<void> {
|
||||
await drainer(timeoutMs)
|
||||
}
|
||||
154
src/services/extractMemories/prompts.ts
Normal file
154
src/services/extractMemories/prompts.ts
Normal file
@@ -0,0 +1,154 @@
|
||||
/**
|
||||
* Prompt templates for the background memory extraction agent.
|
||||
*
|
||||
* The extraction agent runs as a perfect fork of the main conversation — same
|
||||
* system prompt, same message prefix. The main agent's system prompt always
|
||||
* has full save instructions; when the main agent writes memories itself,
|
||||
* extractMemories.ts skips that turn (hasMemoryWritesSince). This prompt
|
||||
* fires only when the main agent didn't write, so the save-criteria here
|
||||
* overlap the system prompt's harmlessly.
|
||||
*/
|
||||
|
||||
import { feature } from 'bun:bundle'
|
||||
import {
|
||||
MEMORY_FRONTMATTER_EXAMPLE,
|
||||
TYPES_SECTION_COMBINED,
|
||||
TYPES_SECTION_INDIVIDUAL,
|
||||
WHAT_NOT_TO_SAVE_SECTION,
|
||||
} from '../../memdir/memoryTypes.js'
|
||||
import { BASH_TOOL_NAME } from '../../tools/BashTool/toolName.js'
|
||||
import { FILE_EDIT_TOOL_NAME } from '../../tools/FileEditTool/constants.js'
|
||||
import { FILE_READ_TOOL_NAME } from '../../tools/FileReadTool/prompt.js'
|
||||
import { FILE_WRITE_TOOL_NAME } from '../../tools/FileWriteTool/prompt.js'
|
||||
import { GLOB_TOOL_NAME } from '../../tools/GlobTool/prompt.js'
|
||||
import { GREP_TOOL_NAME } from '../../tools/GrepTool/prompt.js'
|
||||
|
||||
/**
|
||||
* Shared opener for both extract-prompt variants.
|
||||
*/
|
||||
function opener(newMessageCount: number, existingMemories: string): string {
|
||||
const manifest =
|
||||
existingMemories.length > 0
|
||||
? `\n\n## Existing memory files\n\n${existingMemories}\n\nCheck this list before writing — update an existing file rather than creating a duplicate.`
|
||||
: ''
|
||||
return [
|
||||
`You are now acting as the memory extraction subagent. Analyze the most recent ~${newMessageCount} messages above and use them to update your persistent memory systems.`,
|
||||
'',
|
||||
`Available tools: ${FILE_READ_TOOL_NAME}, ${GREP_TOOL_NAME}, ${GLOB_TOOL_NAME}, read-only ${BASH_TOOL_NAME} (ls/find/cat/stat/wc/head/tail and similar), and ${FILE_EDIT_TOOL_NAME}/${FILE_WRITE_TOOL_NAME} for paths inside the memory directory only. ${BASH_TOOL_NAME} rm is not permitted. All other tools — MCP, Agent, write-capable ${BASH_TOOL_NAME}, etc — will be denied.`,
|
||||
'',
|
||||
`You have a limited turn budget. ${FILE_EDIT_TOOL_NAME} requires a prior ${FILE_READ_TOOL_NAME} of the same file, so the efficient strategy is: turn 1 — issue all ${FILE_READ_TOOL_NAME} calls in parallel for every file you might update; turn 2 — issue all ${FILE_WRITE_TOOL_NAME}/${FILE_EDIT_TOOL_NAME} calls in parallel. Do not interleave reads and writes across multiple turns.`,
|
||||
'',
|
||||
`You MUST only use content from the last ~${newMessageCount} messages to update your persistent memories. Do not waste any turns attempting to investigate or verify that content further — no grepping source files, no reading code to confirm a pattern exists, no git commands.` +
|
||||
manifest,
|
||||
].join('\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the extraction prompt for auto-only memory (no team memory).
|
||||
* Four-type taxonomy, no scope guidance (single directory).
|
||||
*/
|
||||
export function buildExtractAutoOnlyPrompt(
|
||||
newMessageCount: number,
|
||||
existingMemories: string,
|
||||
skipIndex = false,
|
||||
): string {
|
||||
const howToSave = skipIndex
|
||||
? [
|
||||
'## How to save memories',
|
||||
'',
|
||||
'Write each memory to its own file (e.g., `user_role.md`, `feedback_testing.md`) using this frontmatter format:',
|
||||
'',
|
||||
...MEMORY_FRONTMATTER_EXAMPLE,
|
||||
'',
|
||||
'- Organize memory semantically by topic, not chronologically',
|
||||
'- Update or remove memories that turn out to be wrong or outdated',
|
||||
'- Do not write duplicate memories. First check if there is an existing memory you can update before writing a new one.',
|
||||
]
|
||||
: [
|
||||
'## How to save memories',
|
||||
'',
|
||||
'Saving a memory is a two-step process:',
|
||||
'',
|
||||
'**Step 1** — write the memory to its own file (e.g., `user_role.md`, `feedback_testing.md`) using this frontmatter format:',
|
||||
'',
|
||||
...MEMORY_FRONTMATTER_EXAMPLE,
|
||||
'',
|
||||
'**Step 2** — add a pointer to that file in `MEMORY.md`. `MEMORY.md` is an index, not a memory — each entry should be one line, under ~150 characters: `- [Title](file.md) — one-line hook`. It has no frontmatter. Never write memory content directly into `MEMORY.md`.',
|
||||
'',
|
||||
'- `MEMORY.md` is always loaded into your system prompt — lines after 200 will be truncated, so keep the index concise',
|
||||
'- Organize memory semantically by topic, not chronologically',
|
||||
'- Update or remove memories that turn out to be wrong or outdated',
|
||||
'- Do not write duplicate memories. First check if there is an existing memory you can update before writing a new one.',
|
||||
]
|
||||
|
||||
return [
|
||||
opener(newMessageCount, existingMemories),
|
||||
'',
|
||||
'If the user explicitly asks you to remember something, save it immediately as whichever type fits best. If they ask you to forget something, find and remove the relevant entry.',
|
||||
'',
|
||||
...TYPES_SECTION_INDIVIDUAL,
|
||||
...WHAT_NOT_TO_SAVE_SECTION,
|
||||
'',
|
||||
...howToSave,
|
||||
].join('\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the extraction prompt for combined auto + team memory.
|
||||
* Four-type taxonomy with per-type <scope> guidance (directory choice
|
||||
* is baked into each type block, no separate routing section needed).
|
||||
*/
|
||||
export function buildExtractCombinedPrompt(
|
||||
newMessageCount: number,
|
||||
existingMemories: string,
|
||||
skipIndex = false,
|
||||
): string {
|
||||
if (!feature('TEAMMEM')) {
|
||||
return buildExtractAutoOnlyPrompt(
|
||||
newMessageCount,
|
||||
existingMemories,
|
||||
skipIndex,
|
||||
)
|
||||
}
|
||||
|
||||
const howToSave = skipIndex
|
||||
? [
|
||||
'## How to save memories',
|
||||
'',
|
||||
"Write each memory to its own file in the chosen directory (private or team, per the type's scope guidance) using this frontmatter format:",
|
||||
'',
|
||||
...MEMORY_FRONTMATTER_EXAMPLE,
|
||||
'',
|
||||
'- Organize memory semantically by topic, not chronologically',
|
||||
'- Update or remove memories that turn out to be wrong or outdated',
|
||||
'- Do not write duplicate memories. First check if there is an existing memory you can update before writing a new one.',
|
||||
]
|
||||
: [
|
||||
'## How to save memories',
|
||||
'',
|
||||
'Saving a memory is a two-step process:',
|
||||
'',
|
||||
"**Step 1** — write the memory to its own file in the chosen directory (private or team, per the type's scope guidance) using this frontmatter format:",
|
||||
'',
|
||||
...MEMORY_FRONTMATTER_EXAMPLE,
|
||||
'',
|
||||
"**Step 2** — add a pointer to that file in the same directory's `MEMORY.md`. Each directory (private and team) has its own `MEMORY.md` index — each entry should be one line, under ~150 characters: `- [Title](file.md) — one-line hook`. They have no frontmatter. Never write memory content directly into a `MEMORY.md`.",
|
||||
'',
|
||||
'- Both `MEMORY.md` indexes are loaded into your system prompt — lines after 200 will be truncated, so keep them concise',
|
||||
'- Organize memory semantically by topic, not chronologically',
|
||||
'- Update or remove memories that turn out to be wrong or outdated',
|
||||
'- Do not write duplicate memories. First check if there is an existing memory you can update before writing a new one.',
|
||||
]
|
||||
|
||||
return [
|
||||
opener(newMessageCount, existingMemories),
|
||||
'',
|
||||
'If the user explicitly asks you to remember something, save it immediately as whichever type fits best. If they ask you to forget something, find and remove the relevant entry.',
|
||||
'',
|
||||
...TYPES_SECTION_COMBINED,
|
||||
...WHAT_NOT_TO_SAVE_SECTION,
|
||||
'- You MUST avoid saving sensitive data within shared team memories. For example, never save API keys or user credentials.',
|
||||
'',
|
||||
...howToSave,
|
||||
].join('\n')
|
||||
}
|
||||
90
src/services/internalLogging.ts
Normal file
90
src/services/internalLogging.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
import { readFile } from 'fs/promises'
|
||||
import memoize from 'lodash-es/memoize.js'
|
||||
import type { ToolPermissionContext } from '../Tool.js'
|
||||
import { jsonStringify } from '../utils/slowOperations.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from './analytics/index.js'
|
||||
|
||||
/**
|
||||
* Get the current Kubernetes namespace:
|
||||
* Returns null on laptops/local development,
|
||||
* "default" for devboxes in default namespace,
|
||||
* "ts" for devboxes in ts namespace,
|
||||
* ...
|
||||
*/
|
||||
const getKubernetesNamespace = memoize(async (): Promise<string | null> => {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return null
|
||||
}
|
||||
const namespacePath =
|
||||
'/var/run/secrets/kubernetes.io/serviceaccount/namespace'
|
||||
const namespaceNotFound = 'namespace not found'
|
||||
try {
|
||||
const content = await readFile(namespacePath, { encoding: 'utf8' })
|
||||
return content.trim()
|
||||
} catch {
|
||||
return namespaceNotFound
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* Get the OCI container ID from within a running container
|
||||
*/
|
||||
export const getContainerId = memoize(async (): Promise<string | null> => {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return null
|
||||
}
|
||||
const containerIdPath = '/proc/self/mountinfo'
|
||||
const containerIdNotFound = 'container ID not found'
|
||||
const containerIdNotFoundInMountinfo = 'container ID not found in mountinfo'
|
||||
try {
|
||||
const mountinfo = (
|
||||
await readFile(containerIdPath, { encoding: 'utf8' })
|
||||
).trim()
|
||||
|
||||
// Pattern to match both Docker and containerd/CRI-O container IDs
|
||||
// Docker: /docker/containers/[64-char-hex]
|
||||
// Containerd: /sandboxes/[64-char-hex]
|
||||
const containerIdPattern =
|
||||
/(?:\/docker\/containers\/|\/sandboxes\/)([0-9a-f]{64})/
|
||||
|
||||
const lines = mountinfo.split('\n')
|
||||
|
||||
for (const line of lines) {
|
||||
const match = line.match(containerIdPattern)
|
||||
if (match && match[1]) {
|
||||
return match[1]
|
||||
}
|
||||
}
|
||||
|
||||
return containerIdNotFoundInMountinfo
|
||||
} catch {
|
||||
return containerIdNotFound
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* Logs an event with the current namespace and tool permission context
|
||||
*/
|
||||
export async function logPermissionContextForAnts(
|
||||
toolPermissionContext: ToolPermissionContext | null,
|
||||
moment: 'summary' | 'initialization',
|
||||
): Promise<void> {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return
|
||||
}
|
||||
|
||||
void logEvent('tengu_internal_record_permission_context', {
|
||||
moment:
|
||||
moment as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
namespace:
|
||||
(await getKubernetesNamespace()) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
toolPermissionContext: jsonStringify(
|
||||
toolPermissionContext,
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
containerId:
|
||||
(await getContainerId()) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
}
|
||||
447
src/services/lsp/LSPClient.ts
Normal file
447
src/services/lsp/LSPClient.ts
Normal file
@@ -0,0 +1,447 @@
|
||||
import { type ChildProcess, spawn } from 'child_process'
|
||||
import {
|
||||
createMessageConnection,
|
||||
type MessageConnection,
|
||||
StreamMessageReader,
|
||||
StreamMessageWriter,
|
||||
Trace,
|
||||
} from 'vscode-jsonrpc/node.js'
|
||||
import type {
|
||||
InitializeParams,
|
||||
InitializeResult,
|
||||
ServerCapabilities,
|
||||
} from 'vscode-languageserver-protocol'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { subprocessEnv } from '../../utils/subprocessEnv.js'
|
||||
/**
|
||||
* LSP client interface.
|
||||
*/
|
||||
export type LSPClient = {
|
||||
readonly capabilities: ServerCapabilities | undefined
|
||||
readonly isInitialized: boolean
|
||||
start: (
|
||||
command: string,
|
||||
args: string[],
|
||||
options?: {
|
||||
env?: Record<string, string>
|
||||
cwd?: string
|
||||
},
|
||||
) => Promise<void>
|
||||
initialize: (params: InitializeParams) => Promise<InitializeResult>
|
||||
sendRequest: <TResult>(method: string, params: unknown) => Promise<TResult>
|
||||
sendNotification: (method: string, params: unknown) => Promise<void>
|
||||
onNotification: (method: string, handler: (params: unknown) => void) => void
|
||||
onRequest: <TParams, TResult>(
|
||||
method: string,
|
||||
handler: (params: TParams) => TResult | Promise<TResult>,
|
||||
) => void
|
||||
stop: () => Promise<void>
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an LSP client wrapper using vscode-jsonrpc.
|
||||
* Manages communication with an LSP server process via stdio.
|
||||
*
|
||||
* @param onCrash - Called when the server process exits unexpectedly (non-zero
|
||||
* exit code during operation, not during intentional stop). Allows the owner
|
||||
* to propagate crash state so the server can be restarted on next use.
|
||||
*/
|
||||
export function createLSPClient(
|
||||
serverName: string,
|
||||
onCrash?: (error: Error) => void,
|
||||
): LSPClient {
|
||||
// State variables in closure
|
||||
let process: ChildProcess | undefined
|
||||
let connection: MessageConnection | undefined
|
||||
let capabilities: ServerCapabilities | undefined
|
||||
let isInitialized = false
|
||||
let startFailed = false
|
||||
let startError: Error | undefined
|
||||
let isStopping = false // Track intentional shutdown to avoid spurious error logging
|
||||
// Queue handlers registered before connection ready (lazy initialization support)
|
||||
const pendingHandlers: Array<{
|
||||
method: string
|
||||
handler: (params: unknown) => void
|
||||
}> = []
|
||||
const pendingRequestHandlers: Array<{
|
||||
method: string
|
||||
handler: (params: unknown) => unknown | Promise<unknown>
|
||||
}> = []
|
||||
|
||||
function checkStartFailed(): void {
|
||||
if (startFailed) {
|
||||
throw startError || new Error(`LSP server ${serverName} failed to start`)
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
get capabilities(): ServerCapabilities | undefined {
|
||||
return capabilities
|
||||
},
|
||||
|
||||
get isInitialized(): boolean {
|
||||
return isInitialized
|
||||
},
|
||||
|
||||
async start(
|
||||
command: string,
|
||||
args: string[],
|
||||
options?: {
|
||||
env?: Record<string, string>
|
||||
cwd?: string
|
||||
},
|
||||
): Promise<void> {
|
||||
try {
|
||||
// 1. Spawn LSP server process
|
||||
process = spawn(command, args, {
|
||||
stdio: ['pipe', 'pipe', 'pipe'],
|
||||
env: { ...subprocessEnv(), ...options?.env },
|
||||
cwd: options?.cwd,
|
||||
// Prevent visible console window on Windows (no-op on other platforms)
|
||||
windowsHide: true,
|
||||
})
|
||||
|
||||
if (!process.stdout || !process.stdin) {
|
||||
throw new Error('LSP server process stdio not available')
|
||||
}
|
||||
|
||||
// 1.5. Wait for process to successfully spawn before using streams
|
||||
// This is CRITICAL: spawn() returns immediately, but the 'error' event
|
||||
// (e.g., ENOENT for command not found) fires asynchronously.
|
||||
// If we use the streams before confirming spawn succeeded, we get
|
||||
// unhandled promise rejections when writes fail on invalid streams.
|
||||
const spawnedProcess = process // Capture for closure
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const onSpawn = (): void => {
|
||||
cleanup()
|
||||
resolve()
|
||||
}
|
||||
const onError = (error: Error): void => {
|
||||
cleanup()
|
||||
reject(error)
|
||||
}
|
||||
const cleanup = (): void => {
|
||||
spawnedProcess.removeListener('spawn', onSpawn)
|
||||
spawnedProcess.removeListener('error', onError)
|
||||
}
|
||||
spawnedProcess.once('spawn', onSpawn)
|
||||
spawnedProcess.once('error', onError)
|
||||
})
|
||||
|
||||
// Capture stderr for server diagnostics and errors
|
||||
if (process.stderr) {
|
||||
process.stderr.on('data', (data: Buffer) => {
|
||||
const output = data.toString().trim()
|
||||
if (output) {
|
||||
logForDebugging(`[LSP SERVER ${serverName}] ${output}`)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Handle process errors (after successful spawn, e.g., crash during operation)
|
||||
process.on('error', error => {
|
||||
if (!isStopping) {
|
||||
startFailed = true
|
||||
startError = error
|
||||
logError(
|
||||
new Error(
|
||||
`LSP server ${serverName} failed to start: ${error.message}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
process.on('exit', (code, _signal) => {
|
||||
if (code !== 0 && code !== null && !isStopping) {
|
||||
isInitialized = false
|
||||
startFailed = false
|
||||
startError = undefined
|
||||
const crashError = new Error(
|
||||
`LSP server ${serverName} crashed with exit code ${code}`,
|
||||
)
|
||||
logError(crashError)
|
||||
onCrash?.(crashError)
|
||||
}
|
||||
})
|
||||
|
||||
// Handle stdin stream errors to prevent unhandled promise rejections
|
||||
// when the LSP server process exits before we finish writing
|
||||
process.stdin.on('error', (error: Error) => {
|
||||
if (!isStopping) {
|
||||
logForDebugging(
|
||||
`LSP server ${serverName} stdin error: ${error.message}`,
|
||||
)
|
||||
}
|
||||
// Error is logged but not thrown - the connection error handler will catch this
|
||||
})
|
||||
|
||||
// 2. Create JSON-RPC connection
|
||||
const reader = new StreamMessageReader(process.stdout)
|
||||
const writer = new StreamMessageWriter(process.stdin)
|
||||
connection = createMessageConnection(reader, writer)
|
||||
|
||||
// 2.5. Register error/close handlers BEFORE listen() to catch all errors
|
||||
// This prevents unhandled promise rejections when the server crashes or closes unexpectedly
|
||||
connection.onError(([error, _message, _code]) => {
|
||||
// Only log if not intentionally stopping (avoid spurious errors during shutdown)
|
||||
if (!isStopping) {
|
||||
startFailed = true
|
||||
startError = error
|
||||
logError(
|
||||
new Error(
|
||||
`LSP server ${serverName} connection error: ${error.message}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
connection.onClose(() => {
|
||||
// Only treat as error if not intentionally stopping
|
||||
if (!isStopping) {
|
||||
isInitialized = false
|
||||
// Don't set startFailed here - the connection may close after graceful shutdown
|
||||
logForDebugging(`LSP server ${serverName} connection closed`)
|
||||
}
|
||||
})
|
||||
|
||||
// 3. Start listening for messages
|
||||
connection.listen()
|
||||
|
||||
// 3.5. Enable protocol tracing for debugging
|
||||
// Note: trace() sends a $/setTrace notification which can fail if the server
|
||||
// process has already exited. We catch and log the error rather than letting
|
||||
// it become an unhandled promise rejection.
|
||||
connection
|
||||
.trace(Trace.Verbose, {
|
||||
log: (message: string) => {
|
||||
logForDebugging(`[LSP PROTOCOL ${serverName}] ${message}`)
|
||||
},
|
||||
})
|
||||
.catch((error: Error) => {
|
||||
logForDebugging(
|
||||
`Failed to enable tracing for ${serverName}: ${error.message}`,
|
||||
)
|
||||
})
|
||||
|
||||
// 4. Apply any queued notification handlers
|
||||
for (const { method, handler } of pendingHandlers) {
|
||||
connection.onNotification(method, handler)
|
||||
logForDebugging(
|
||||
`Applied queued notification handler for ${serverName}.${method}`,
|
||||
)
|
||||
}
|
||||
pendingHandlers.length = 0 // Clear the queue
|
||||
|
||||
// 5. Apply any queued request handlers
|
||||
for (const { method, handler } of pendingRequestHandlers) {
|
||||
connection.onRequest(method, handler)
|
||||
logForDebugging(
|
||||
`Applied queued request handler for ${serverName}.${method}`,
|
||||
)
|
||||
}
|
||||
pendingRequestHandlers.length = 0 // Clear the queue
|
||||
|
||||
logForDebugging(`LSP client started for ${serverName}`)
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(`LSP server ${serverName} failed to start: ${err.message}`),
|
||||
)
|
||||
throw error
|
||||
}
|
||||
},
|
||||
|
||||
async initialize(params: InitializeParams): Promise<InitializeResult> {
|
||||
if (!connection) {
|
||||
throw new Error('LSP client not started')
|
||||
}
|
||||
|
||||
checkStartFailed()
|
||||
|
||||
try {
|
||||
const result: InitializeResult = await connection.sendRequest(
|
||||
'initialize',
|
||||
params,
|
||||
)
|
||||
|
||||
capabilities = result.capabilities
|
||||
|
||||
// Send initialized notification
|
||||
await connection.sendNotification('initialized', {})
|
||||
|
||||
isInitialized = true
|
||||
logForDebugging(`LSP server ${serverName} initialized`)
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(
|
||||
`LSP server ${serverName} initialize failed: ${err.message}`,
|
||||
),
|
||||
)
|
||||
throw error
|
||||
}
|
||||
},
|
||||
|
||||
async sendRequest<TResult>(
|
||||
method: string,
|
||||
params: unknown,
|
||||
): Promise<TResult> {
|
||||
if (!connection) {
|
||||
throw new Error('LSP client not started')
|
||||
}
|
||||
|
||||
checkStartFailed()
|
||||
|
||||
if (!isInitialized) {
|
||||
throw new Error('LSP server not initialized')
|
||||
}
|
||||
|
||||
try {
|
||||
return await connection.sendRequest(method, params)
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(
|
||||
`LSP server ${serverName} request ${method} failed: ${err.message}`,
|
||||
),
|
||||
)
|
||||
throw error
|
||||
}
|
||||
},
|
||||
|
||||
async sendNotification(method: string, params: unknown): Promise<void> {
|
||||
if (!connection) {
|
||||
throw new Error('LSP client not started')
|
||||
}
|
||||
|
||||
checkStartFailed()
|
||||
|
||||
try {
|
||||
await connection.sendNotification(method, params)
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(
|
||||
`LSP server ${serverName} notification ${method} failed: ${err.message}`,
|
||||
),
|
||||
)
|
||||
// Don't re-throw for notifications - they're fire-and-forget
|
||||
logForDebugging(`Notification ${method} failed but continuing`)
|
||||
}
|
||||
},
|
||||
|
||||
onNotification(method: string, handler: (params: unknown) => void): void {
|
||||
if (!connection) {
|
||||
// Queue handler for application when connection is ready (lazy initialization)
|
||||
pendingHandlers.push({ method, handler })
|
||||
logForDebugging(
|
||||
`Queued notification handler for ${serverName}.${method} (connection not ready)`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
checkStartFailed()
|
||||
|
||||
connection.onNotification(method, handler)
|
||||
},
|
||||
|
||||
onRequest<TParams, TResult>(
|
||||
method: string,
|
||||
handler: (params: TParams) => TResult | Promise<TResult>,
|
||||
): void {
|
||||
if (!connection) {
|
||||
// Queue handler for application when connection is ready (lazy initialization)
|
||||
pendingRequestHandlers.push({
|
||||
method,
|
||||
handler: handler as (params: unknown) => unknown | Promise<unknown>,
|
||||
})
|
||||
logForDebugging(
|
||||
`Queued request handler for ${serverName}.${method} (connection not ready)`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
checkStartFailed()
|
||||
|
||||
connection.onRequest(method, handler)
|
||||
},
|
||||
|
||||
async stop(): Promise<void> {
|
||||
let shutdownError: Error | undefined
|
||||
|
||||
// Mark as stopping to prevent error handlers from logging spurious errors
|
||||
isStopping = true
|
||||
|
||||
try {
|
||||
if (connection) {
|
||||
// Try to send shutdown request and exit notification
|
||||
await connection.sendRequest('shutdown', {})
|
||||
await connection.sendNotification('exit', {})
|
||||
}
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(`LSP server ${serverName} stop failed: ${err.message}`),
|
||||
)
|
||||
shutdownError = err
|
||||
// Continue to cleanup despite shutdown failure
|
||||
} finally {
|
||||
// Always cleanup resources, even if shutdown/exit failed
|
||||
if (connection) {
|
||||
try {
|
||||
connection.dispose()
|
||||
} catch (error) {
|
||||
// Log but don't throw - disposal errors are less critical
|
||||
logForDebugging(
|
||||
`Connection disposal failed for ${serverName}: ${errorMessage(error)}`,
|
||||
)
|
||||
}
|
||||
connection = undefined
|
||||
}
|
||||
|
||||
if (process) {
|
||||
// Remove event listeners to prevent memory leaks
|
||||
process.removeAllListeners('error')
|
||||
process.removeAllListeners('exit')
|
||||
if (process.stdin) {
|
||||
process.stdin.removeAllListeners('error')
|
||||
}
|
||||
if (process.stderr) {
|
||||
process.stderr.removeAllListeners('data')
|
||||
}
|
||||
|
||||
try {
|
||||
process.kill()
|
||||
} catch (error) {
|
||||
// Process might already be dead, which is fine
|
||||
logForDebugging(
|
||||
`Process kill failed for ${serverName} (may already be dead): ${errorMessage(error)}`,
|
||||
)
|
||||
}
|
||||
process = undefined
|
||||
}
|
||||
|
||||
isInitialized = false
|
||||
capabilities = undefined
|
||||
isStopping = false // Reset for potential restart
|
||||
// Don't reset startFailed - preserve error state for diagnostics
|
||||
// startFailed and startError remain as-is
|
||||
if (shutdownError) {
|
||||
startFailed = true
|
||||
startError = shutdownError
|
||||
}
|
||||
|
||||
logForDebugging(`LSP client stopped for ${serverName}`)
|
||||
}
|
||||
|
||||
// Re-throw shutdown error after cleanup is complete
|
||||
if (shutdownError) {
|
||||
throw shutdownError
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
386
src/services/lsp/LSPDiagnosticRegistry.ts
Normal file
386
src/services/lsp/LSPDiagnosticRegistry.ts
Normal file
@@ -0,0 +1,386 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { LRUCache } from 'lru-cache'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { toError } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import type { DiagnosticFile } from '../diagnosticTracking.js'
|
||||
|
||||
/**
|
||||
* Pending LSP diagnostic notification
|
||||
*/
|
||||
export type PendingLSPDiagnostic = {
|
||||
/** Server that sent the diagnostic */
|
||||
serverName: string
|
||||
/** Diagnostic files */
|
||||
files: DiagnosticFile[]
|
||||
/** When diagnostic was received */
|
||||
timestamp: number
|
||||
/** Whether attachment was already sent to conversation */
|
||||
attachmentSent: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* LSP Diagnostic Registry
|
||||
*
|
||||
* Stores LSP diagnostics received asynchronously from LSP servers via
|
||||
* textDocument/publishDiagnostics notifications. Follows the same pattern
|
||||
* as AsyncHookRegistry for consistent async attachment delivery.
|
||||
*
|
||||
* Pattern:
|
||||
* 1. LSP server sends publishDiagnostics notification
|
||||
* 2. registerPendingLSPDiagnostic() stores diagnostic
|
||||
* 3. checkForLSPDiagnostics() retrieves pending diagnostics
|
||||
* 4. getLSPDiagnosticAttachments() converts to Attachment[]
|
||||
* 5. getAttachments() delivers to conversation automatically
|
||||
*
|
||||
* Similar to AsyncHookRegistry but simpler since diagnostics arrive
|
||||
* synchronously (no need to accumulate output over time).
|
||||
*/
|
||||
|
||||
// Volume limiting constants
|
||||
const MAX_DIAGNOSTICS_PER_FILE = 10
|
||||
const MAX_TOTAL_DIAGNOSTICS = 30
|
||||
|
||||
// Max files to track for deduplication - prevents unbounded memory growth
|
||||
const MAX_DELIVERED_FILES = 500
|
||||
|
||||
// Global registry state
|
||||
const pendingDiagnostics = new Map<string, PendingLSPDiagnostic>()
|
||||
|
||||
// Cross-turn deduplication: tracks diagnostics that have been delivered
|
||||
// Maps file URI to a set of diagnostic keys (hash of message+severity+range)
|
||||
// Using LRUCache to prevent unbounded growth in long sessions
|
||||
const deliveredDiagnostics = new LRUCache<string, Set<string>>({
|
||||
max: MAX_DELIVERED_FILES,
|
||||
})
|
||||
|
||||
/**
|
||||
* Register LSP diagnostics received from a server.
|
||||
* These will be delivered as attachments in the next query.
|
||||
*
|
||||
* @param serverName - Name of LSP server that sent diagnostics
|
||||
* @param files - Diagnostic files to deliver
|
||||
*/
|
||||
export function registerPendingLSPDiagnostic({
|
||||
serverName,
|
||||
files,
|
||||
}: {
|
||||
serverName: string
|
||||
files: DiagnosticFile[]
|
||||
}): void {
|
||||
// Use UUID for guaranteed uniqueness (handles rapid registrations)
|
||||
const diagnosticId = randomUUID()
|
||||
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Registering ${files.length} diagnostic file(s) from ${serverName} (ID: ${diagnosticId})`,
|
||||
)
|
||||
|
||||
pendingDiagnostics.set(diagnosticId, {
|
||||
serverName,
|
||||
files,
|
||||
timestamp: Date.now(),
|
||||
attachmentSent: false,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps severity string to numeric value for sorting.
|
||||
* Error=1, Warning=2, Info=3, Hint=4
|
||||
*/
|
||||
function severityToNumber(severity: string | undefined): number {
|
||||
switch (severity) {
|
||||
case 'Error':
|
||||
return 1
|
||||
case 'Warning':
|
||||
return 2
|
||||
case 'Info':
|
||||
return 3
|
||||
case 'Hint':
|
||||
return 4
|
||||
default:
|
||||
return 4
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a unique key for a diagnostic based on its content.
|
||||
* Used for both within-batch and cross-turn deduplication.
|
||||
*/
|
||||
function createDiagnosticKey(diag: {
|
||||
message: string
|
||||
severity?: string
|
||||
range?: unknown
|
||||
source?: string
|
||||
code?: unknown
|
||||
}): string {
|
||||
return jsonStringify({
|
||||
message: diag.message,
|
||||
severity: diag.severity,
|
||||
range: diag.range,
|
||||
source: diag.source || null,
|
||||
code: diag.code || null,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Deduplicates diagnostics by file URI and diagnostic content.
|
||||
* Also filters out diagnostics that were already delivered in previous turns.
|
||||
* Two diagnostics are considered duplicates if they have the same:
|
||||
* - File URI
|
||||
* - Range (start/end line and character)
|
||||
* - Message
|
||||
* - Severity
|
||||
* - Source and code (if present)
|
||||
*/
|
||||
function deduplicateDiagnosticFiles(
|
||||
allFiles: DiagnosticFile[],
|
||||
): DiagnosticFile[] {
|
||||
// Group diagnostics by file URI
|
||||
const fileMap = new Map<string, Set<string>>()
|
||||
const dedupedFiles: DiagnosticFile[] = []
|
||||
|
||||
for (const file of allFiles) {
|
||||
if (!fileMap.has(file.uri)) {
|
||||
fileMap.set(file.uri, new Set())
|
||||
dedupedFiles.push({ uri: file.uri, diagnostics: [] })
|
||||
}
|
||||
|
||||
const seenDiagnostics = fileMap.get(file.uri)!
|
||||
const dedupedFile = dedupedFiles.find(f => f.uri === file.uri)!
|
||||
|
||||
// Get previously delivered diagnostics for this file (for cross-turn dedup)
|
||||
const previouslyDelivered = deliveredDiagnostics.get(file.uri) || new Set()
|
||||
|
||||
for (const diag of file.diagnostics) {
|
||||
try {
|
||||
const key = createDiagnosticKey(diag)
|
||||
|
||||
// Skip if already seen in this batch OR already delivered in previous turns
|
||||
if (seenDiagnostics.has(key) || previouslyDelivered.has(key)) {
|
||||
continue
|
||||
}
|
||||
|
||||
seenDiagnostics.add(key)
|
||||
dedupedFile.diagnostics.push(diag)
|
||||
} catch (error: unknown) {
|
||||
const err = toError(error)
|
||||
const truncatedMessage =
|
||||
diag.message?.substring(0, 100) || '<no message>'
|
||||
logError(
|
||||
new Error(
|
||||
`Failed to deduplicate diagnostic in ${file.uri}: ${err.message}. ` +
|
||||
`Diagnostic message: ${truncatedMessage}`,
|
||||
),
|
||||
)
|
||||
// Include the diagnostic anyway to avoid losing information
|
||||
dedupedFile.diagnostics.push(diag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out files with no diagnostics after deduplication
|
||||
return dedupedFiles.filter(f => f.diagnostics.length > 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all pending LSP diagnostics that haven't been delivered yet.
|
||||
* Deduplicates diagnostics to prevent sending the same diagnostic multiple times.
|
||||
* Marks diagnostics as sent to prevent duplicate delivery.
|
||||
*
|
||||
* @returns Array of pending diagnostics ready for delivery (deduplicated)
|
||||
*/
|
||||
export function checkForLSPDiagnostics(): Array<{
|
||||
serverName: string
|
||||
files: DiagnosticFile[]
|
||||
}> {
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Checking registry - ${pendingDiagnostics.size} pending`,
|
||||
)
|
||||
|
||||
// Collect all diagnostic files from all pending notifications
|
||||
const allFiles: DiagnosticFile[] = []
|
||||
const serverNames = new Set<string>()
|
||||
const diagnosticsToMark: PendingLSPDiagnostic[] = []
|
||||
|
||||
for (const diagnostic of pendingDiagnostics.values()) {
|
||||
if (!diagnostic.attachmentSent) {
|
||||
allFiles.push(...diagnostic.files)
|
||||
serverNames.add(diagnostic.serverName)
|
||||
diagnosticsToMark.push(diagnostic)
|
||||
}
|
||||
}
|
||||
|
||||
if (allFiles.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
// Deduplicate diagnostics across all files
|
||||
let dedupedFiles: DiagnosticFile[]
|
||||
try {
|
||||
dedupedFiles = deduplicateDiagnosticFiles(allFiles)
|
||||
} catch (error: unknown) {
|
||||
const err = toError(error)
|
||||
logError(new Error(`Failed to deduplicate LSP diagnostics: ${err.message}`))
|
||||
// Fall back to undedup'd files to avoid losing diagnostics
|
||||
dedupedFiles = allFiles
|
||||
}
|
||||
|
||||
// Only mark as sent AFTER successful deduplication, then delete from map.
|
||||
// Entries are tracked in deliveredDiagnostics LRU for dedup, so we don't
|
||||
// need to keep them in pendingDiagnostics after delivery.
|
||||
for (const diagnostic of diagnosticsToMark) {
|
||||
diagnostic.attachmentSent = true
|
||||
}
|
||||
for (const [id, diagnostic] of pendingDiagnostics) {
|
||||
if (diagnostic.attachmentSent) {
|
||||
pendingDiagnostics.delete(id)
|
||||
}
|
||||
}
|
||||
|
||||
const originalCount = allFiles.reduce(
|
||||
(sum, f) => sum + f.diagnostics.length,
|
||||
0,
|
||||
)
|
||||
const dedupedCount = dedupedFiles.reduce(
|
||||
(sum, f) => sum + f.diagnostics.length,
|
||||
0,
|
||||
)
|
||||
|
||||
if (originalCount > dedupedCount) {
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Deduplication removed ${originalCount - dedupedCount} duplicate diagnostic(s)`,
|
||||
)
|
||||
}
|
||||
|
||||
// Apply volume limiting: cap per file and total
|
||||
let totalDiagnostics = 0
|
||||
let truncatedCount = 0
|
||||
for (const file of dedupedFiles) {
|
||||
// Sort by severity (Error=1 < Warning=2 < Info=3 < Hint=4) to prioritize errors
|
||||
file.diagnostics.sort(
|
||||
(a, b) => severityToNumber(a.severity) - severityToNumber(b.severity),
|
||||
)
|
||||
|
||||
// Cap per file
|
||||
if (file.diagnostics.length > MAX_DIAGNOSTICS_PER_FILE) {
|
||||
truncatedCount += file.diagnostics.length - MAX_DIAGNOSTICS_PER_FILE
|
||||
file.diagnostics = file.diagnostics.slice(0, MAX_DIAGNOSTICS_PER_FILE)
|
||||
}
|
||||
|
||||
// Cap total
|
||||
const remainingCapacity = MAX_TOTAL_DIAGNOSTICS - totalDiagnostics
|
||||
if (file.diagnostics.length > remainingCapacity) {
|
||||
truncatedCount += file.diagnostics.length - remainingCapacity
|
||||
file.diagnostics = file.diagnostics.slice(0, remainingCapacity)
|
||||
}
|
||||
|
||||
totalDiagnostics += file.diagnostics.length
|
||||
}
|
||||
|
||||
// Filter out files that ended up with no diagnostics after limiting
|
||||
dedupedFiles = dedupedFiles.filter(f => f.diagnostics.length > 0)
|
||||
|
||||
if (truncatedCount > 0) {
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Volume limiting removed ${truncatedCount} diagnostic(s) (max ${MAX_DIAGNOSTICS_PER_FILE}/file, ${MAX_TOTAL_DIAGNOSTICS} total)`,
|
||||
)
|
||||
}
|
||||
|
||||
// Track delivered diagnostics for cross-turn deduplication
|
||||
for (const file of dedupedFiles) {
|
||||
if (!deliveredDiagnostics.has(file.uri)) {
|
||||
deliveredDiagnostics.set(file.uri, new Set())
|
||||
}
|
||||
const delivered = deliveredDiagnostics.get(file.uri)!
|
||||
for (const diag of file.diagnostics) {
|
||||
try {
|
||||
delivered.add(createDiagnosticKey(diag))
|
||||
} catch (error: unknown) {
|
||||
// Log but continue - failure to track shouldn't prevent delivery
|
||||
const err = toError(error)
|
||||
const truncatedMessage =
|
||||
diag.message?.substring(0, 100) || '<no message>'
|
||||
logError(
|
||||
new Error(
|
||||
`Failed to track delivered diagnostic in ${file.uri}: ${err.message}. ` +
|
||||
`Diagnostic message: ${truncatedMessage}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const finalCount = dedupedFiles.reduce(
|
||||
(sum, f) => sum + f.diagnostics.length,
|
||||
0,
|
||||
)
|
||||
|
||||
// Return empty if no diagnostics to deliver (all filtered by deduplication)
|
||||
if (finalCount === 0) {
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: No new diagnostics to deliver (all filtered by deduplication)`,
|
||||
)
|
||||
return []
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Delivering ${dedupedFiles.length} file(s) with ${finalCount} diagnostic(s) from ${serverNames.size} server(s)`,
|
||||
)
|
||||
|
||||
// Return single result with all deduplicated diagnostics
|
||||
return [
|
||||
{
|
||||
serverName: Array.from(serverNames).join(', '),
|
||||
files: dedupedFiles,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all pending diagnostics.
|
||||
* Used during cleanup/shutdown or for testing.
|
||||
* Note: Does NOT clear deliveredDiagnostics - that's for cross-turn deduplication
|
||||
* and should only be cleared when files are edited or on session reset.
|
||||
*/
|
||||
export function clearAllLSPDiagnostics(): void {
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Clearing ${pendingDiagnostics.size} pending diagnostic(s)`,
|
||||
)
|
||||
pendingDiagnostics.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset all diagnostic state including cross-turn tracking.
|
||||
* Used on session reset or for testing.
|
||||
*/
|
||||
export function resetAllLSPDiagnosticState(): void {
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Resetting all state (${pendingDiagnostics.size} pending, ${deliveredDiagnostics.size} files tracked)`,
|
||||
)
|
||||
pendingDiagnostics.clear()
|
||||
deliveredDiagnostics.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear delivered diagnostics for a specific file.
|
||||
* Should be called when a file is edited so that new diagnostics for that file
|
||||
* will be shown even if they match previously delivered ones.
|
||||
*
|
||||
* @param fileUri - URI of the file that was edited
|
||||
*/
|
||||
export function clearDeliveredDiagnosticsForFile(fileUri: string): void {
|
||||
if (deliveredDiagnostics.has(fileUri)) {
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Clearing delivered diagnostics for ${fileUri}`,
|
||||
)
|
||||
deliveredDiagnostics.delete(fileUri)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get count of pending diagnostics (for monitoring)
|
||||
*/
|
||||
export function getPendingLSPDiagnosticCount(): number {
|
||||
return pendingDiagnostics.size
|
||||
}
|
||||
511
src/services/lsp/LSPServerInstance.ts
Normal file
511
src/services/lsp/LSPServerInstance.ts
Normal file
@@ -0,0 +1,511 @@
|
||||
import * as path from 'path'
|
||||
import { pathToFileURL } from 'url'
|
||||
import type { InitializeParams } from 'vscode-languageserver-protocol'
|
||||
import { getCwd } from '../../utils/cwd.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { sleep } from '../../utils/sleep.js'
|
||||
import type { createLSPClient as createLSPClientType } from './LSPClient.js'
|
||||
import type { LspServerState, ScopedLspServerConfig } from './types.js'
|
||||
|
||||
/**
|
||||
* LSP error code for "content modified" - indicates the server's state changed
|
||||
* during request processing (e.g., rust-analyzer still indexing the project).
|
||||
* This is a transient error that can be retried.
|
||||
*/
|
||||
const LSP_ERROR_CONTENT_MODIFIED = -32801
|
||||
|
||||
/**
|
||||
* Maximum number of retries for transient LSP errors like "content modified".
|
||||
*/
|
||||
const MAX_RETRIES_FOR_TRANSIENT_ERRORS = 3
|
||||
|
||||
/**
|
||||
* Base delay in milliseconds for exponential backoff on transient errors.
|
||||
* Actual delays: 500ms, 1000ms, 2000ms
|
||||
*/
|
||||
const RETRY_BASE_DELAY_MS = 500
|
||||
/**
|
||||
* LSP server instance interface returned by createLSPServerInstance.
|
||||
* Manages the lifecycle of a single LSP server with state tracking and health monitoring.
|
||||
*/
|
||||
export type LSPServerInstance = {
|
||||
/** Unique server identifier */
|
||||
readonly name: string
|
||||
/** Server configuration */
|
||||
readonly config: ScopedLspServerConfig
|
||||
/** Current server state */
|
||||
readonly state: LspServerState
|
||||
/** When the server was last started */
|
||||
readonly startTime: Date | undefined
|
||||
/** Last error encountered */
|
||||
readonly lastError: Error | undefined
|
||||
/** Number of times restart() has been called */
|
||||
readonly restartCount: number
|
||||
/** Start the server and initialize it */
|
||||
start(): Promise<void>
|
||||
/** Stop the server gracefully */
|
||||
stop(): Promise<void>
|
||||
/** Manually restart the server (stop then start) */
|
||||
restart(): Promise<void>
|
||||
/** Check if server is healthy and ready for requests */
|
||||
isHealthy(): boolean
|
||||
/** Send an LSP request to the server */
|
||||
sendRequest<T>(method: string, params: unknown): Promise<T>
|
||||
/** Send an LSP notification to the server (fire-and-forget) */
|
||||
sendNotification(method: string, params: unknown): Promise<void>
|
||||
/** Register a handler for LSP notifications */
|
||||
onNotification(method: string, handler: (params: unknown) => void): void
|
||||
/** Register a handler for LSP requests from the server */
|
||||
onRequest<TParams, TResult>(
|
||||
method: string,
|
||||
handler: (params: TParams) => TResult | Promise<TResult>,
|
||||
): void
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and manages a single LSP server instance.
|
||||
*
|
||||
* Uses factory function pattern with closures for state encapsulation (avoiding classes).
|
||||
* Provides state tracking, health monitoring, and request forwarding for an LSP server.
|
||||
* Supports manual restart with configurable retry limits.
|
||||
*
|
||||
* State machine transitions:
|
||||
* - stopped → starting → running
|
||||
* - running → stopping → stopped
|
||||
* - any → error (on failure)
|
||||
* - error → starting (on retry)
|
||||
*
|
||||
* @param name - Unique identifier for this server instance
|
||||
* @param config - Server configuration including command, args, and limits
|
||||
* @returns LSP server instance with lifecycle management methods
|
||||
*
|
||||
* @example
|
||||
* const instance = createLSPServerInstance('my-server', config)
|
||||
* await instance.start()
|
||||
* const result = await instance.sendRequest('textDocument/definition', params)
|
||||
* await instance.stop()
|
||||
*/
|
||||
export function createLSPServerInstance(
|
||||
name: string,
|
||||
config: ScopedLspServerConfig,
|
||||
): LSPServerInstance {
|
||||
// Validate that unimplemented fields are not set
|
||||
if (config.restartOnCrash !== undefined) {
|
||||
throw new Error(
|
||||
`LSP server '${name}': restartOnCrash is not yet implemented. Remove this field from the configuration.`,
|
||||
)
|
||||
}
|
||||
if (config.shutdownTimeout !== undefined) {
|
||||
throw new Error(
|
||||
`LSP server '${name}': shutdownTimeout is not yet implemented. Remove this field from the configuration.`,
|
||||
)
|
||||
}
|
||||
|
||||
// Private state encapsulated via closures. Lazy-require LSPClient so
|
||||
// vscode-jsonrpc (~129KB) only loads when an LSP server is actually
|
||||
// instantiated, not when the static import chain reaches this module.
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
const { createLSPClient } = require('./LSPClient.js') as {
|
||||
createLSPClient: typeof createLSPClientType
|
||||
}
|
||||
let state: LspServerState = 'stopped'
|
||||
let startTime: Date | undefined
|
||||
let lastError: Error | undefined
|
||||
let restartCount = 0
|
||||
let crashRecoveryCount = 0
|
||||
// Propagate crash state so ensureServerStarted can restart on next use.
|
||||
// Without this, state stays 'running' after crash and the server is never
|
||||
// restarted (zombie state).
|
||||
const client = createLSPClient(name, error => {
|
||||
state = 'error'
|
||||
lastError = error
|
||||
crashRecoveryCount++
|
||||
})
|
||||
|
||||
/**
|
||||
* Starts the LSP server and initializes it with workspace information.
|
||||
*
|
||||
* If the server is already running or starting, this method returns immediately.
|
||||
* On failure, sets state to 'error', logs for monitoring, and throws.
|
||||
*
|
||||
* @throws {Error} If server fails to start or initialize
|
||||
*/
|
||||
async function start(): Promise<void> {
|
||||
if (state === 'running' || state === 'starting') {
|
||||
return
|
||||
}
|
||||
|
||||
// Cap crash-recovery attempts so a persistently crashing server doesn't
|
||||
// spawn unbounded child processes on every incoming request.
|
||||
const maxRestarts = config.maxRestarts ?? 3
|
||||
if (state === 'error' && crashRecoveryCount > maxRestarts) {
|
||||
const error = new Error(
|
||||
`LSP server '${name}' exceeded max crash recovery attempts (${maxRestarts})`,
|
||||
)
|
||||
lastError = error
|
||||
logError(error)
|
||||
throw error
|
||||
}
|
||||
|
||||
let initPromise: Promise<unknown> | undefined
|
||||
try {
|
||||
state = 'starting'
|
||||
logForDebugging(`Starting LSP server instance: ${name}`)
|
||||
|
||||
// Start the client
|
||||
await client.start(config.command, config.args || [], {
|
||||
env: config.env,
|
||||
cwd: config.workspaceFolder,
|
||||
})
|
||||
|
||||
// Initialize with workspace info
|
||||
const workspaceFolder = config.workspaceFolder || getCwd()
|
||||
const workspaceUri = pathToFileURL(workspaceFolder).href
|
||||
|
||||
const initParams: InitializeParams = {
|
||||
processId: process.pid,
|
||||
|
||||
// Pass server-specific initialization options from plugin config
|
||||
// Required by vue-language-server, optional for others
|
||||
// Provide empty object as default to avoid undefined errors in servers
|
||||
// that expect this field to exist
|
||||
initializationOptions: config.initializationOptions ?? {},
|
||||
|
||||
// Modern approach (LSP 3.16+) - required for Pyright, gopls
|
||||
workspaceFolders: [
|
||||
{
|
||||
uri: workspaceUri,
|
||||
name: path.basename(workspaceFolder),
|
||||
},
|
||||
],
|
||||
|
||||
// Deprecated fields - some servers still need these for proper URI resolution
|
||||
rootPath: workspaceFolder, // Deprecated in LSP 3.8 but needed by some servers
|
||||
rootUri: workspaceUri, // Deprecated in LSP 3.16 but needed by typescript-language-server for goToDefinition
|
||||
|
||||
// Client capabilities - declare what features we support
|
||||
capabilities: {
|
||||
workspace: {
|
||||
// Don't claim to support workspace/configuration since we don't implement it
|
||||
// This prevents servers from requesting config we can't provide
|
||||
configuration: false,
|
||||
// Don't claim to support workspace folders changes since we don't handle
|
||||
// workspace/didChangeWorkspaceFolders notifications
|
||||
workspaceFolders: false,
|
||||
},
|
||||
textDocument: {
|
||||
synchronization: {
|
||||
dynamicRegistration: false,
|
||||
willSave: false,
|
||||
willSaveWaitUntil: false,
|
||||
didSave: true,
|
||||
},
|
||||
publishDiagnostics: {
|
||||
relatedInformation: true,
|
||||
tagSupport: {
|
||||
valueSet: [1, 2], // Unnecessary (1), Deprecated (2)
|
||||
},
|
||||
versionSupport: false,
|
||||
codeDescriptionSupport: true,
|
||||
dataSupport: false,
|
||||
},
|
||||
hover: {
|
||||
dynamicRegistration: false,
|
||||
contentFormat: ['markdown', 'plaintext'],
|
||||
},
|
||||
definition: {
|
||||
dynamicRegistration: false,
|
||||
linkSupport: true,
|
||||
},
|
||||
references: {
|
||||
dynamicRegistration: false,
|
||||
},
|
||||
documentSymbol: {
|
||||
dynamicRegistration: false,
|
||||
hierarchicalDocumentSymbolSupport: true,
|
||||
},
|
||||
callHierarchy: {
|
||||
dynamicRegistration: false,
|
||||
},
|
||||
},
|
||||
general: {
|
||||
positionEncodings: ['utf-16'],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
initPromise = client.initialize(initParams)
|
||||
if (config.startupTimeout !== undefined) {
|
||||
await withTimeout(
|
||||
initPromise,
|
||||
config.startupTimeout,
|
||||
`LSP server '${name}' timed out after ${config.startupTimeout}ms during initialization`,
|
||||
)
|
||||
} else {
|
||||
await initPromise
|
||||
}
|
||||
|
||||
state = 'running'
|
||||
startTime = new Date()
|
||||
crashRecoveryCount = 0
|
||||
logForDebugging(`LSP server instance started: ${name}`)
|
||||
} catch (error) {
|
||||
// Clean up the spawned child process on timeout/error
|
||||
client.stop().catch(() => {})
|
||||
// Prevent unhandled rejection from abandoned initialize promise
|
||||
initPromise?.catch(() => {})
|
||||
state = 'error'
|
||||
lastError = error as Error
|
||||
logError(error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops the LSP server gracefully.
|
||||
*
|
||||
* If already stopped or stopping, returns immediately.
|
||||
* On failure, sets state to 'error', logs for monitoring, and throws.
|
||||
*
|
||||
* @throws {Error} If server fails to stop
|
||||
*/
|
||||
async function stop(): Promise<void> {
|
||||
if (state === 'stopped' || state === 'stopping') {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
state = 'stopping'
|
||||
await client.stop()
|
||||
state = 'stopped'
|
||||
logForDebugging(`LSP server instance stopped: ${name}`)
|
||||
} catch (error) {
|
||||
state = 'error'
|
||||
lastError = error as Error
|
||||
logError(error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually restarts the server by stopping and starting it.
|
||||
*
|
||||
* Increments restartCount and enforces maxRestarts limit.
|
||||
* Note: This is NOT automatic - must be called explicitly.
|
||||
*
|
||||
* @throws {Error} If stop or start fails, or if restartCount exceeds config.maxRestarts (default: 3)
|
||||
*/
|
||||
async function restart(): Promise<void> {
|
||||
try {
|
||||
await stop()
|
||||
} catch (error) {
|
||||
const stopError = new Error(
|
||||
`Failed to stop LSP server '${name}' during restart: ${errorMessage(error)}`,
|
||||
)
|
||||
logError(stopError)
|
||||
throw stopError
|
||||
}
|
||||
|
||||
restartCount++
|
||||
|
||||
const maxRestarts = config.maxRestarts ?? 3
|
||||
if (restartCount > maxRestarts) {
|
||||
const error = new Error(
|
||||
`Max restart attempts (${maxRestarts}) exceeded for server '${name}'`,
|
||||
)
|
||||
logError(error)
|
||||
throw error
|
||||
}
|
||||
|
||||
try {
|
||||
await start()
|
||||
} catch (error) {
|
||||
const startError = new Error(
|
||||
`Failed to start LSP server '${name}' during restart (attempt ${restartCount}/${maxRestarts}): ${errorMessage(error)}`,
|
||||
)
|
||||
logError(startError)
|
||||
throw startError
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the server is healthy and ready to handle requests.
|
||||
*
|
||||
* @returns true if state is 'running' AND the client has completed initialization
|
||||
*/
|
||||
function isHealthy(): boolean {
|
||||
return state === 'running' && client.isInitialized
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends an LSP request to the server with retry logic for transient errors.
|
||||
*
|
||||
* Checks server health before sending and wraps errors with context.
|
||||
* Automatically retries on "content modified" errors (code -32801) which occur
|
||||
* when servers like rust-analyzer are still indexing. This is expected LSP behavior
|
||||
* and clients should retry silently per the LSP specification.
|
||||
*
|
||||
* @param method - LSP method name (e.g., 'textDocument/definition')
|
||||
* @param params - Method-specific parameters
|
||||
* @returns The server's response
|
||||
* @throws {Error} If server is not healthy or request fails after all retries
|
||||
*/
|
||||
async function sendRequest<T>(method: string, params: unknown): Promise<T> {
|
||||
if (!isHealthy()) {
|
||||
const error = new Error(
|
||||
`Cannot send request to LSP server '${name}': server is ${state}` +
|
||||
`${lastError ? `, last error: ${lastError.message}` : ''}`,
|
||||
)
|
||||
logError(error)
|
||||
throw error
|
||||
}
|
||||
|
||||
let lastAttemptError: Error | undefined
|
||||
|
||||
for (
|
||||
let attempt = 0;
|
||||
attempt <= MAX_RETRIES_FOR_TRANSIENT_ERRORS;
|
||||
attempt++
|
||||
) {
|
||||
try {
|
||||
return await client.sendRequest(method, params)
|
||||
} catch (error) {
|
||||
lastAttemptError = error as Error
|
||||
|
||||
// Check if this is a transient "content modified" error that we should retry
|
||||
// This commonly happens with rust-analyzer during initial project indexing.
|
||||
// We use duck typing instead of instanceof because there may be multiple
|
||||
// versions of vscode-jsonrpc in the dependency tree (8.2.0 vs 8.2.1).
|
||||
const errorCode = (error as { code?: number }).code
|
||||
const isContentModifiedError =
|
||||
typeof errorCode === 'number' &&
|
||||
errorCode === LSP_ERROR_CONTENT_MODIFIED
|
||||
|
||||
if (
|
||||
isContentModifiedError &&
|
||||
attempt < MAX_RETRIES_FOR_TRANSIENT_ERRORS
|
||||
) {
|
||||
const delay = RETRY_BASE_DELAY_MS * Math.pow(2, attempt)
|
||||
logForDebugging(
|
||||
`LSP request '${method}' to '${name}' got ContentModified error, ` +
|
||||
`retrying in ${delay}ms (attempt ${attempt + 1}/${MAX_RETRIES_FOR_TRANSIENT_ERRORS})…`,
|
||||
)
|
||||
await sleep(delay)
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-retryable error or max retries exceeded
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// All retries failed or non-retryable error
|
||||
const requestError = new Error(
|
||||
`LSP request '${method}' failed for server '${name}': ${lastAttemptError?.message ?? 'unknown error'}`,
|
||||
)
|
||||
logError(requestError)
|
||||
throw requestError
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a notification to the LSP server (fire-and-forget).
|
||||
* Used for file synchronization (didOpen, didChange, didClose).
|
||||
*/
|
||||
async function sendNotification(
|
||||
method: string,
|
||||
params: unknown,
|
||||
): Promise<void> {
|
||||
if (!isHealthy()) {
|
||||
const error = new Error(
|
||||
`Cannot send notification to LSP server '${name}': server is ${state}`,
|
||||
)
|
||||
logError(error)
|
||||
throw error
|
||||
}
|
||||
|
||||
try {
|
||||
await client.sendNotification(method, params)
|
||||
} catch (error) {
|
||||
const notificationError = new Error(
|
||||
`LSP notification '${method}' failed for server '${name}': ${errorMessage(error)}`,
|
||||
)
|
||||
logError(notificationError)
|
||||
throw notificationError
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a handler for LSP notifications from the server.
|
||||
*
|
||||
* @param method - LSP notification method (e.g., 'window/logMessage')
|
||||
* @param handler - Callback function to handle the notification
|
||||
*/
|
||||
function onNotification(
|
||||
method: string,
|
||||
handler: (params: unknown) => void,
|
||||
): void {
|
||||
client.onNotification(method, handler)
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a handler for LSP requests from the server.
|
||||
*
|
||||
* Some LSP servers send requests TO the client (reverse direction).
|
||||
* This allows registering handlers for such requests.
|
||||
*
|
||||
* @param method - LSP request method (e.g., 'workspace/configuration')
|
||||
* @param handler - Callback function to handle the request and return a response
|
||||
*/
|
||||
function onRequest<TParams, TResult>(
|
||||
method: string,
|
||||
handler: (params: TParams) => TResult | Promise<TResult>,
|
||||
): void {
|
||||
client.onRequest(method, handler)
|
||||
}
|
||||
|
||||
// Return public API
|
||||
return {
|
||||
name,
|
||||
config,
|
||||
get state() {
|
||||
return state
|
||||
},
|
||||
get startTime() {
|
||||
return startTime
|
||||
},
|
||||
get lastError() {
|
||||
return lastError
|
||||
},
|
||||
get restartCount() {
|
||||
return restartCount
|
||||
},
|
||||
start,
|
||||
stop,
|
||||
restart,
|
||||
isHealthy,
|
||||
sendRequest,
|
||||
sendNotification,
|
||||
onNotification,
|
||||
onRequest,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Race a promise against a timeout. Cleans up the timer regardless of outcome
|
||||
* to avoid unhandled rejections from orphaned setTimeout callbacks.
|
||||
*/
|
||||
function withTimeout<T>(
|
||||
promise: Promise<T>,
|
||||
ms: number,
|
||||
message: string,
|
||||
): Promise<T> {
|
||||
let timer: ReturnType<typeof setTimeout>
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
timer = setTimeout((rej, msg) => rej(new Error(msg)), ms, reject, message)
|
||||
})
|
||||
return Promise.race([promise, timeoutPromise]).finally(() =>
|
||||
clearTimeout(timer!),
|
||||
)
|
||||
}
|
||||
420
src/services/lsp/LSPServerManager.ts
Normal file
420
src/services/lsp/LSPServerManager.ts
Normal file
@@ -0,0 +1,420 @@
|
||||
import * as path from 'path'
|
||||
import { pathToFileURL } from 'url'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getAllLspServers } from './config.js'
|
||||
import {
|
||||
createLSPServerInstance,
|
||||
type LSPServerInstance,
|
||||
} from './LSPServerInstance.js'
|
||||
import type { ScopedLspServerConfig } from './types.js'
|
||||
/**
|
||||
* LSP Server Manager interface returned by createLSPServerManager.
|
||||
* Manages multiple LSP server instances and routes requests based on file extensions.
|
||||
*/
|
||||
export type LSPServerManager = {
|
||||
/** Initialize the manager by loading all configured LSP servers */
|
||||
initialize(): Promise<void>
|
||||
/** Shutdown all running servers and clear state */
|
||||
shutdown(): Promise<void>
|
||||
/** Get the LSP server instance for a given file path */
|
||||
getServerForFile(filePath: string): LSPServerInstance | undefined
|
||||
/** Ensure the appropriate LSP server is started for the given file */
|
||||
ensureServerStarted(filePath: string): Promise<LSPServerInstance | undefined>
|
||||
/** Send a request to the appropriate LSP server for the given file */
|
||||
sendRequest<T>(
|
||||
filePath: string,
|
||||
method: string,
|
||||
params: unknown,
|
||||
): Promise<T | undefined>
|
||||
/** Get all running server instances */
|
||||
getAllServers(): Map<string, LSPServerInstance>
|
||||
/** Synchronize file open to LSP server (sends didOpen notification) */
|
||||
openFile(filePath: string, content: string): Promise<void>
|
||||
/** Synchronize file change to LSP server (sends didChange notification) */
|
||||
changeFile(filePath: string, content: string): Promise<void>
|
||||
/** Synchronize file save to LSP server (sends didSave notification) */
|
||||
saveFile(filePath: string): Promise<void>
|
||||
/** Synchronize file close to LSP server (sends didClose notification) */
|
||||
closeFile(filePath: string): Promise<void>
|
||||
/** Check if a file is already open on a compatible LSP server */
|
||||
isFileOpen(filePath: string): boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an LSP server manager instance.
|
||||
*
|
||||
* Manages multiple LSP server instances and routes requests based on file extensions.
|
||||
* Uses factory function pattern with closures for state encapsulation (avoiding classes).
|
||||
*
|
||||
* @returns LSP server manager instance
|
||||
*
|
||||
* @example
|
||||
* const manager = createLSPServerManager()
|
||||
* await manager.initialize()
|
||||
* const result = await manager.sendRequest('/path/to/file.ts', 'textDocument/definition', params)
|
||||
* await manager.shutdown()
|
||||
*/
|
||||
export function createLSPServerManager(): LSPServerManager {
|
||||
// Private state managed via closures
|
||||
const servers: Map<string, LSPServerInstance> = new Map()
|
||||
const extensionMap: Map<string, string[]> = new Map()
|
||||
// Track which files have been opened on which servers (URI -> server name)
|
||||
const openedFiles: Map<string, string> = new Map()
|
||||
|
||||
/**
|
||||
* Initialize the manager by loading all configured LSP servers.
|
||||
*
|
||||
* @throws {Error} If configuration loading fails
|
||||
*/
|
||||
async function initialize(): Promise<void> {
|
||||
let serverConfigs: Record<string, ScopedLspServerConfig>
|
||||
|
||||
try {
|
||||
const result = await getAllLspServers()
|
||||
serverConfigs = result.servers
|
||||
logForDebugging(
|
||||
`[LSP SERVER MANAGER] getAllLspServers returned ${Object.keys(serverConfigs).length} server(s)`,
|
||||
)
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(`Failed to load LSP server configuration: ${err.message}`),
|
||||
)
|
||||
throw error
|
||||
}
|
||||
|
||||
// Build extension → server mapping
|
||||
for (const [serverName, config] of Object.entries(serverConfigs)) {
|
||||
try {
|
||||
// Validate config before using it
|
||||
if (!config.command) {
|
||||
throw new Error(
|
||||
`Server ${serverName} missing required 'command' field`,
|
||||
)
|
||||
}
|
||||
if (
|
||||
!config.extensionToLanguage ||
|
||||
Object.keys(config.extensionToLanguage).length === 0
|
||||
) {
|
||||
throw new Error(
|
||||
`Server ${serverName} missing required 'extensionToLanguage' field`,
|
||||
)
|
||||
}
|
||||
|
||||
// Map file extensions to this server (derive from extensionToLanguage)
|
||||
const fileExtensions = Object.keys(config.extensionToLanguage)
|
||||
for (const ext of fileExtensions) {
|
||||
const normalized = ext.toLowerCase()
|
||||
if (!extensionMap.has(normalized)) {
|
||||
extensionMap.set(normalized, [])
|
||||
}
|
||||
const serverList = extensionMap.get(normalized)
|
||||
if (serverList) {
|
||||
serverList.push(serverName)
|
||||
}
|
||||
}
|
||||
|
||||
// Create server instance
|
||||
const instance = createLSPServerInstance(serverName, config)
|
||||
servers.set(serverName, instance)
|
||||
|
||||
// Register handler for workspace/configuration requests from the server
|
||||
// Some servers (like TypeScript) send these even when we say we don't support them
|
||||
instance.onRequest(
|
||||
'workspace/configuration',
|
||||
(params: { items: Array<{ section?: string }> }) => {
|
||||
logForDebugging(
|
||||
`LSP: Received workspace/configuration request from ${serverName}`,
|
||||
)
|
||||
// Return empty/null config for each requested item
|
||||
// This satisfies the protocol without providing actual configuration
|
||||
return params.items.map(() => null)
|
||||
},
|
||||
)
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(
|
||||
`Failed to initialize LSP server ${serverName}: ${err.message}`,
|
||||
),
|
||||
)
|
||||
// Continue with other servers - don't fail entire initialization
|
||||
}
|
||||
}
|
||||
|
||||
logForDebugging(`LSP manager initialized with ${servers.size} servers`)
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown all running servers and clear state.
|
||||
* Only servers in 'running' state are explicitly stopped;
|
||||
* servers in other states are cleared without shutdown.
|
||||
*
|
||||
* @throws {Error} If one or more servers fail to stop
|
||||
*/
|
||||
async function shutdown(): Promise<void> {
|
||||
const toStop = Array.from(servers.entries()).filter(
|
||||
([, s]) => s.state === 'running' || s.state === 'error',
|
||||
)
|
||||
|
||||
const results = await Promise.allSettled(
|
||||
toStop.map(([, server]) => server.stop()),
|
||||
)
|
||||
|
||||
servers.clear()
|
||||
extensionMap.clear()
|
||||
openedFiles.clear()
|
||||
|
||||
const errors = results
|
||||
.map((r, i) =>
|
||||
r.status === 'rejected'
|
||||
? `${toStop[i]![0]}: ${errorMessage(r.reason)}`
|
||||
: null,
|
||||
)
|
||||
.filter((e): e is string => e !== null)
|
||||
|
||||
if (errors.length > 0) {
|
||||
const err = new Error(
|
||||
`Failed to stop ${errors.length} LSP server(s): ${errors.join('; ')}`,
|
||||
)
|
||||
logError(err)
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the LSP server instance for a given file path.
|
||||
* If multiple servers handle the same extension, returns the first registered server.
|
||||
* Returns undefined if no server handles this file type.
|
||||
*/
|
||||
function getServerForFile(filePath: string): LSPServerInstance | undefined {
|
||||
const ext = path.extname(filePath).toLowerCase()
|
||||
const serverNames = extensionMap.get(ext)
|
||||
|
||||
if (!serverNames || serverNames.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Use first server (can add priority later)
|
||||
const serverName = serverNames[0]
|
||||
if (!serverName) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return servers.get(serverName)
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure the appropriate LSP server is started for the given file.
|
||||
* Returns undefined if no server handles this file type.
|
||||
*
|
||||
* @throws {Error} If server fails to start
|
||||
*/
|
||||
async function ensureServerStarted(
|
||||
filePath: string,
|
||||
): Promise<LSPServerInstance | undefined> {
|
||||
const server = getServerForFile(filePath)
|
||||
if (!server) return undefined
|
||||
|
||||
if (server.state === 'stopped' || server.state === 'error') {
|
||||
try {
|
||||
await server.start()
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(
|
||||
`Failed to start LSP server for file ${filePath}: ${err.message}`,
|
||||
),
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a request to the appropriate LSP server for the given file.
|
||||
* Returns undefined if no server handles this file type.
|
||||
*
|
||||
* @throws {Error} If server fails to start or request fails
|
||||
*/
|
||||
async function sendRequest<T>(
|
||||
filePath: string,
|
||||
method: string,
|
||||
params: unknown,
|
||||
): Promise<T | undefined> {
|
||||
const server = await ensureServerStarted(filePath)
|
||||
if (!server) return undefined
|
||||
|
||||
try {
|
||||
return await server.sendRequest<T>(method, params)
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logError(
|
||||
new Error(
|
||||
`LSP request failed for file ${filePath}, method '${method}': ${err.message}`,
|
||||
),
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// Return public interface
|
||||
function getAllServers(): Map<string, LSPServerInstance> {
|
||||
return servers
|
||||
}
|
||||
|
||||
async function openFile(filePath: string, content: string): Promise<void> {
|
||||
const server = await ensureServerStarted(filePath)
|
||||
if (!server) return
|
||||
|
||||
const fileUri = pathToFileURL(path.resolve(filePath)).href
|
||||
|
||||
// Skip if already opened on this server
|
||||
if (openedFiles.get(fileUri) === server.name) {
|
||||
logForDebugging(
|
||||
`LSP: File already open, skipping didOpen for ${filePath}`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Get language ID from server's extensionToLanguage mapping
|
||||
const ext = path.extname(filePath).toLowerCase()
|
||||
const languageId = server.config.extensionToLanguage[ext] || 'plaintext'
|
||||
|
||||
try {
|
||||
await server.sendNotification('textDocument/didOpen', {
|
||||
textDocument: {
|
||||
uri: fileUri,
|
||||
languageId,
|
||||
version: 1,
|
||||
text: content,
|
||||
},
|
||||
})
|
||||
// Track that this file is now open on this server
|
||||
openedFiles.set(fileUri, server.name)
|
||||
logForDebugging(
|
||||
`LSP: Sent didOpen for ${filePath} (languageId: ${languageId})`,
|
||||
)
|
||||
} catch (error) {
|
||||
const err = new Error(
|
||||
`Failed to sync file open ${filePath}: ${errorMessage(error)}`,
|
||||
)
|
||||
logError(err)
|
||||
// Re-throw to propagate error to caller
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
async function changeFile(filePath: string, content: string): Promise<void> {
|
||||
const server = getServerForFile(filePath)
|
||||
if (!server || server.state !== 'running') {
|
||||
return openFile(filePath, content)
|
||||
}
|
||||
|
||||
const fileUri = pathToFileURL(path.resolve(filePath)).href
|
||||
|
||||
// If file hasn't been opened on this server yet, open it first
|
||||
// LSP servers require didOpen before didChange
|
||||
if (openedFiles.get(fileUri) !== server.name) {
|
||||
return openFile(filePath, content)
|
||||
}
|
||||
|
||||
try {
|
||||
await server.sendNotification('textDocument/didChange', {
|
||||
textDocument: {
|
||||
uri: fileUri,
|
||||
version: 1,
|
||||
},
|
||||
contentChanges: [{ text: content }],
|
||||
})
|
||||
logForDebugging(`LSP: Sent didChange for ${filePath}`)
|
||||
} catch (error) {
|
||||
const err = new Error(
|
||||
`Failed to sync file change ${filePath}: ${errorMessage(error)}`,
|
||||
)
|
||||
logError(err)
|
||||
// Re-throw to propagate error to caller
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save a file in LSP servers (sends didSave notification)
|
||||
* Called after file is written to disk to trigger diagnostics
|
||||
*/
|
||||
async function saveFile(filePath: string): Promise<void> {
|
||||
const server = getServerForFile(filePath)
|
||||
if (!server || server.state !== 'running') return
|
||||
|
||||
try {
|
||||
await server.sendNotification('textDocument/didSave', {
|
||||
textDocument: {
|
||||
uri: pathToFileURL(path.resolve(filePath)).href,
|
||||
},
|
||||
})
|
||||
logForDebugging(`LSP: Sent didSave for ${filePath}`)
|
||||
} catch (error) {
|
||||
const err = new Error(
|
||||
`Failed to sync file save ${filePath}: ${errorMessage(error)}`,
|
||||
)
|
||||
logError(err)
|
||||
// Re-throw to propagate error to caller
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Close a file in LSP servers (sends didClose notification)
|
||||
*
|
||||
* NOTE: Currently available but not yet integrated with compact flow.
|
||||
* TODO: Integrate with compact - call closeFile() when compact removes files from context
|
||||
* This will notify LSP servers that files are no longer in active use.
|
||||
*/
|
||||
async function closeFile(filePath: string): Promise<void> {
|
||||
const server = getServerForFile(filePath)
|
||||
if (!server || server.state !== 'running') return
|
||||
|
||||
const fileUri = pathToFileURL(path.resolve(filePath)).href
|
||||
|
||||
try {
|
||||
await server.sendNotification('textDocument/didClose', {
|
||||
textDocument: {
|
||||
uri: fileUri,
|
||||
},
|
||||
})
|
||||
// Remove from tracking so file can be reopened later
|
||||
openedFiles.delete(fileUri)
|
||||
logForDebugging(`LSP: Sent didClose for ${filePath}`)
|
||||
} catch (error) {
|
||||
const err = new Error(
|
||||
`Failed to sync file close ${filePath}: ${errorMessage(error)}`,
|
||||
)
|
||||
logError(err)
|
||||
// Re-throw to propagate error to caller
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
function isFileOpen(filePath: string): boolean {
|
||||
const fileUri = pathToFileURL(path.resolve(filePath)).href
|
||||
return openedFiles.has(fileUri)
|
||||
}
|
||||
|
||||
return {
|
||||
initialize,
|
||||
shutdown,
|
||||
getServerForFile,
|
||||
ensureServerStarted,
|
||||
sendRequest,
|
||||
getAllServers,
|
||||
openFile,
|
||||
changeFile,
|
||||
saveFile,
|
||||
closeFile,
|
||||
isFileOpen,
|
||||
}
|
||||
}
|
||||
79
src/services/lsp/config.ts
Normal file
79
src/services/lsp/config.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import type { PluginError } from '../../types/plugin.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage, toError } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getPluginLspServers } from '../../utils/plugins/lspPluginIntegration.js'
|
||||
import { loadAllPluginsCacheOnly } from '../../utils/plugins/pluginLoader.js'
|
||||
import type { ScopedLspServerConfig } from './types.js'
|
||||
|
||||
/**
|
||||
* Get all configured LSP servers from plugins.
|
||||
* LSP servers are only supported via plugins, not user/project settings.
|
||||
*
|
||||
* @returns Object containing servers configuration keyed by scoped server name
|
||||
*/
|
||||
export async function getAllLspServers(): Promise<{
|
||||
servers: Record<string, ScopedLspServerConfig>
|
||||
}> {
|
||||
const allServers: Record<string, ScopedLspServerConfig> = {}
|
||||
|
||||
try {
|
||||
// Get all enabled plugins
|
||||
const { enabled: plugins } = await loadAllPluginsCacheOnly()
|
||||
|
||||
// Load LSP servers from each plugin in parallel.
|
||||
// Each plugin is independent — results are merged in original order so
|
||||
// Object.assign collision precedence (later plugins win) is preserved.
|
||||
const results = await Promise.all(
|
||||
plugins.map(async plugin => {
|
||||
const errors: PluginError[] = []
|
||||
try {
|
||||
const scopedServers = await getPluginLspServers(plugin, errors)
|
||||
return { plugin, scopedServers, errors }
|
||||
} catch (e) {
|
||||
// Defensive: if one plugin throws, don't lose results from the
|
||||
// others. The previous serial loop implicitly tolerated this.
|
||||
logForDebugging(
|
||||
`Failed to load LSP servers for plugin ${plugin.name}: ${e}`,
|
||||
{ level: 'error' },
|
||||
)
|
||||
return { plugin, scopedServers: undefined, errors }
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
for (const { plugin, scopedServers, errors } of results) {
|
||||
const serverCount = scopedServers ? Object.keys(scopedServers).length : 0
|
||||
if (serverCount > 0) {
|
||||
// Merge into all servers (already scoped by getPluginLspServers)
|
||||
Object.assign(allServers, scopedServers)
|
||||
|
||||
logForDebugging(
|
||||
`Loaded ${serverCount} LSP server(s) from plugin: ${plugin.name}`,
|
||||
)
|
||||
}
|
||||
|
||||
// Log any errors encountered
|
||||
if (errors.length > 0) {
|
||||
logForDebugging(
|
||||
`${errors.length} error(s) loading LSP servers from plugin: ${plugin.name}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`Total LSP servers loaded: ${Object.keys(allServers).length}`,
|
||||
)
|
||||
} catch (error) {
|
||||
// Log error for monitoring production issues.
|
||||
// LSP is optional, so we don't throw - but we need visibility
|
||||
// into why plugin loading fails to improve the feature.
|
||||
logError(toError(error))
|
||||
|
||||
logForDebugging(`Error loading LSP servers: ${errorMessage(error)}`)
|
||||
}
|
||||
|
||||
return {
|
||||
servers: allServers,
|
||||
}
|
||||
}
|
||||
289
src/services/lsp/manager.ts
Normal file
289
src/services/lsp/manager.ts
Normal file
@@ -0,0 +1,289 @@
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { isBareMode } from '../../utils/envUtils.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import {
|
||||
createLSPServerManager,
|
||||
type LSPServerManager,
|
||||
} from './LSPServerManager.js'
|
||||
import { registerLSPNotificationHandlers } from './passiveFeedback.js'
|
||||
|
||||
/**
|
||||
* Initialization state of the LSP server manager
|
||||
*/
|
||||
type InitializationState = 'not-started' | 'pending' | 'success' | 'failed'
|
||||
|
||||
/**
|
||||
* Global singleton instance of the LSP server manager.
|
||||
* Initialized during Claude Code startup.
|
||||
*/
|
||||
let lspManagerInstance: LSPServerManager | undefined
|
||||
|
||||
/**
|
||||
* Current initialization state
|
||||
*/
|
||||
let initializationState: InitializationState = 'not-started'
|
||||
|
||||
/**
|
||||
* Error from last initialization attempt, if any
|
||||
*/
|
||||
let initializationError: Error | undefined
|
||||
|
||||
/**
|
||||
* Generation counter to prevent stale initialization promises from updating state
|
||||
*/
|
||||
let initializationGeneration = 0
|
||||
|
||||
/**
|
||||
* Promise that resolves when initialization completes (success or failure)
|
||||
*/
|
||||
let initializationPromise: Promise<void> | undefined
|
||||
|
||||
/**
|
||||
* Test-only sync reset. shutdownLspServerManager() is async and tears down
|
||||
* real connections; this only clears the module-scope singleton state so
|
||||
* reinitializeLspServerManager() early-returns on 'not-started' in downstream
|
||||
* tests on the same shard.
|
||||
*/
|
||||
export function _resetLspManagerForTesting(): void {
|
||||
initializationState = 'not-started'
|
||||
initializationError = undefined
|
||||
initializationPromise = undefined
|
||||
initializationGeneration++
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton LSP server manager instance.
|
||||
* Returns undefined if not yet initialized, initialization failed, or still pending.
|
||||
*
|
||||
* Callers should check for undefined and handle gracefully, as initialization happens
|
||||
* asynchronously during Claude Code startup. Use getInitializationStatus() to
|
||||
* distinguish between pending, failed, and not-started states.
|
||||
*/
|
||||
export function getLspServerManager(): LSPServerManager | undefined {
|
||||
// Don't return a broken instance if initialization failed
|
||||
if (initializationState === 'failed') {
|
||||
return undefined
|
||||
}
|
||||
return lspManagerInstance
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current initialization status of the LSP server manager.
|
||||
*
|
||||
* @returns Status object with current state and error (if failed)
|
||||
*/
|
||||
export function getInitializationStatus():
|
||||
| { status: 'not-started' }
|
||||
| { status: 'pending' }
|
||||
| { status: 'success' }
|
||||
| { status: 'failed'; error: Error } {
|
||||
if (initializationState === 'failed') {
|
||||
return {
|
||||
status: 'failed',
|
||||
error: initializationError || new Error('Initialization failed'),
|
||||
}
|
||||
}
|
||||
if (initializationState === 'not-started') {
|
||||
return { status: 'not-started' }
|
||||
}
|
||||
if (initializationState === 'pending') {
|
||||
return { status: 'pending' }
|
||||
}
|
||||
return { status: 'success' }
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether at least one language server is connected and healthy.
|
||||
* Backs LSPTool.isEnabled().
|
||||
*/
|
||||
export function isLspConnected(): boolean {
|
||||
if (initializationState === 'failed') return false
|
||||
const manager = getLspServerManager()
|
||||
if (!manager) return false
|
||||
const servers = manager.getAllServers()
|
||||
if (servers.size === 0) return false
|
||||
for (const server of servers.values()) {
|
||||
if (server.state !== 'error') return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for LSP server manager initialization to complete.
|
||||
*
|
||||
* Returns immediately if initialization has already completed (success or failure).
|
||||
* If initialization is pending, waits for it to complete.
|
||||
* If initialization hasn't started, returns immediately.
|
||||
*
|
||||
* @returns Promise that resolves when initialization is complete
|
||||
*/
|
||||
export async function waitForInitialization(): Promise<void> {
|
||||
// If already initialized or failed, return immediately
|
||||
if (initializationState === 'success' || initializationState === 'failed') {
|
||||
return
|
||||
}
|
||||
|
||||
// If pending and we have a promise, wait for it
|
||||
if (initializationState === 'pending' && initializationPromise) {
|
||||
await initializationPromise
|
||||
}
|
||||
|
||||
// If not started, return immediately (nothing to wait for)
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the LSP server manager singleton.
|
||||
*
|
||||
* This function is called during Claude Code startup. It synchronously creates
|
||||
* the manager instance, then starts async initialization (loading LSP configs)
|
||||
* in the background without blocking the startup process.
|
||||
*
|
||||
* Safe to call multiple times - will only initialize once (idempotent).
|
||||
* However, if initialization previously failed, calling again will retry.
|
||||
*/
|
||||
export function initializeLspServerManager(): void {
|
||||
// --bare / SIMPLE: no LSP. LSP is for editor integration (diagnostics,
|
||||
// hover, go-to-def in the REPL). Scripted -p calls have no use for it.
|
||||
if (isBareMode()) {
|
||||
return
|
||||
}
|
||||
logForDebugging('[LSP MANAGER] initializeLspServerManager() called')
|
||||
|
||||
// Skip if already initialized or currently initializing
|
||||
if (lspManagerInstance !== undefined && initializationState !== 'failed') {
|
||||
logForDebugging(
|
||||
'[LSP MANAGER] Already initialized or initializing, skipping',
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Reset state for retry if previous initialization failed
|
||||
if (initializationState === 'failed') {
|
||||
lspManagerInstance = undefined
|
||||
initializationError = undefined
|
||||
}
|
||||
|
||||
// Create the manager instance and mark as pending
|
||||
lspManagerInstance = createLSPServerManager()
|
||||
initializationState = 'pending'
|
||||
logForDebugging('[LSP MANAGER] Created manager instance, state=pending')
|
||||
|
||||
// Increment generation to invalidate any pending initializations
|
||||
const currentGeneration = ++initializationGeneration
|
||||
logForDebugging(
|
||||
`[LSP MANAGER] Starting async initialization (generation ${currentGeneration})`,
|
||||
)
|
||||
|
||||
// Start initialization asynchronously without blocking
|
||||
// Store the promise so callers can await it via waitForInitialization()
|
||||
initializationPromise = lspManagerInstance
|
||||
.initialize()
|
||||
.then(() => {
|
||||
// Only update state if this is still the current initialization
|
||||
if (currentGeneration === initializationGeneration) {
|
||||
initializationState = 'success'
|
||||
logForDebugging('LSP server manager initialized successfully')
|
||||
|
||||
// Register passive notification handlers for diagnostics
|
||||
if (lspManagerInstance) {
|
||||
registerLSPNotificationHandlers(lspManagerInstance)
|
||||
}
|
||||
}
|
||||
})
|
||||
.catch((error: unknown) => {
|
||||
// Only update state if this is still the current initialization
|
||||
if (currentGeneration === initializationGeneration) {
|
||||
initializationState = 'failed'
|
||||
initializationError = error as Error
|
||||
// Clear the instance since it's not usable
|
||||
lspManagerInstance = undefined
|
||||
|
||||
logError(error as Error)
|
||||
logForDebugging(
|
||||
`Failed to initialize LSP server manager: ${errorMessage(error)}`,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Force re-initialization of the LSP server manager, even after a prior
|
||||
* successful init. Called from refreshActivePlugins() after plugin caches
|
||||
* are cleared, so newly-loaded plugin LSP servers are picked up.
|
||||
*
|
||||
* Fixes https://github.com/anthropics/claude-code/issues/15521:
|
||||
* loadAllPlugins() is memoized and can be called very early in startup
|
||||
* (via getCommands prefetch in setup.ts) before marketplaces are reconciled,
|
||||
* caching an empty plugin list. initializeLspServerManager() then reads that
|
||||
* stale memoized result and initializes with 0 servers. Unlike commands/agents/
|
||||
* hooks/MCP, LSP was never re-initialized on plugin refresh.
|
||||
*
|
||||
* Safe to call when no LSP plugins changed: initialize() is just config
|
||||
* parsing (servers are lazy-started on first use). Also safe during pending
|
||||
* init: the generation counter invalidates the in-flight promise.
|
||||
*/
|
||||
export function reinitializeLspServerManager(): void {
|
||||
if (initializationState === 'not-started') {
|
||||
// initializeLspServerManager() was never called (e.g. headless subcommand
|
||||
// path). Don't start it now.
|
||||
return
|
||||
}
|
||||
|
||||
logForDebugging('[LSP MANAGER] reinitializeLspServerManager() called')
|
||||
|
||||
// Best-effort shutdown of any running servers on the old instance so
|
||||
// /reload-plugins doesn't leak child processes. Fire-and-forget: the
|
||||
// primary use case (issue #15521) has 0 servers so this is usually a no-op.
|
||||
if (lspManagerInstance) {
|
||||
void lspManagerInstance.shutdown().catch(err => {
|
||||
logForDebugging(
|
||||
`[LSP MANAGER] old instance shutdown during reinit failed: ${errorMessage(err)}`,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// Force the idempotence check in initializeLspServerManager() to fall
|
||||
// through. Generation counter handles invalidating any in-flight init.
|
||||
lspManagerInstance = undefined
|
||||
initializationState = 'not-started'
|
||||
initializationError = undefined
|
||||
|
||||
initializeLspServerManager()
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown the LSP server manager and clean up resources.
|
||||
*
|
||||
* This should be called during Claude Code shutdown. Stops all running LSP servers
|
||||
* and clears internal state. Safe to call when not initialized (no-op).
|
||||
*
|
||||
* NOTE: Errors during shutdown are logged for monitoring but NOT propagated to the caller.
|
||||
* State is always cleared even if shutdown fails, to prevent resource accumulation.
|
||||
* This is acceptable during application exit when recovery is not possible.
|
||||
*
|
||||
* @returns Promise that resolves when shutdown completes (errors are swallowed)
|
||||
*/
|
||||
export async function shutdownLspServerManager(): Promise<void> {
|
||||
if (lspManagerInstance === undefined) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
await lspManagerInstance.shutdown()
|
||||
logForDebugging('LSP server manager shut down successfully')
|
||||
} catch (error: unknown) {
|
||||
logError(error as Error)
|
||||
logForDebugging(
|
||||
`Failed to shutdown LSP server manager: ${errorMessage(error)}`,
|
||||
)
|
||||
} finally {
|
||||
// Always clear state even if shutdown failed
|
||||
lspManagerInstance = undefined
|
||||
initializationState = 'not-started'
|
||||
initializationError = undefined
|
||||
initializationPromise = undefined
|
||||
// Increment generation to invalidate any pending initializations
|
||||
initializationGeneration++
|
||||
}
|
||||
}
|
||||
328
src/services/lsp/passiveFeedback.ts
Normal file
328
src/services/lsp/passiveFeedback.ts
Normal file
@@ -0,0 +1,328 @@
|
||||
import { fileURLToPath } from 'url'
|
||||
import type { PublishDiagnosticsParams } from 'vscode-languageserver-protocol'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { toError } from '../../utils/errors.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import type { DiagnosticFile } from '../diagnosticTracking.js'
|
||||
import { registerPendingLSPDiagnostic } from './LSPDiagnosticRegistry.js'
|
||||
import type { LSPServerManager } from './LSPServerManager.js'
|
||||
|
||||
/**
|
||||
* Map LSP severity to Claude diagnostic severity
|
||||
*
|
||||
* Maps LSP severity numbers to Claude diagnostic severity strings.
|
||||
* Accepts numeric severity values (1=Error, 2=Warning, 3=Information, 4=Hint)
|
||||
* or undefined, defaulting to 'Error' for invalid/missing values.
|
||||
*/
|
||||
function mapLSPSeverity(
|
||||
lspSeverity: number | undefined,
|
||||
): 'Error' | 'Warning' | 'Info' | 'Hint' {
|
||||
// LSP DiagnosticSeverity enum:
|
||||
// 1 = Error, 2 = Warning, 3 = Information, 4 = Hint
|
||||
switch (lspSeverity) {
|
||||
case 1:
|
||||
return 'Error'
|
||||
case 2:
|
||||
return 'Warning'
|
||||
case 3:
|
||||
return 'Info'
|
||||
case 4:
|
||||
return 'Hint'
|
||||
default:
|
||||
return 'Error'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert LSP diagnostics to Claude diagnostic format
|
||||
*
|
||||
* Converts LSP PublishDiagnosticsParams to DiagnosticFile[] format
|
||||
* used by Claude's attachment system.
|
||||
*/
|
||||
export function formatDiagnosticsForAttachment(
|
||||
params: PublishDiagnosticsParams,
|
||||
): DiagnosticFile[] {
|
||||
// Parse URI (may be file:// or plain path) and normalize to file system path
|
||||
let uri: string
|
||||
try {
|
||||
// Handle both file:// URIs and plain paths
|
||||
uri = params.uri.startsWith('file://')
|
||||
? fileURLToPath(params.uri)
|
||||
: params.uri
|
||||
} catch (error) {
|
||||
const err = toError(error)
|
||||
logError(err)
|
||||
logForDebugging(
|
||||
`Failed to convert URI to file path: ${params.uri}. Error: ${err.message}. Using original URI as fallback.`,
|
||||
)
|
||||
// Gracefully fallback to original URI - LSP servers may send malformed URIs
|
||||
uri = params.uri
|
||||
}
|
||||
|
||||
const diagnostics = params.diagnostics.map(
|
||||
(diag: {
|
||||
message: string
|
||||
severity?: number
|
||||
range: {
|
||||
start: { line: number; character: number }
|
||||
end: { line: number; character: number }
|
||||
}
|
||||
source?: string
|
||||
code?: string | number
|
||||
}) => ({
|
||||
message: diag.message,
|
||||
severity: mapLSPSeverity(diag.severity),
|
||||
range: {
|
||||
start: {
|
||||
line: diag.range.start.line,
|
||||
character: diag.range.start.character,
|
||||
},
|
||||
end: {
|
||||
line: diag.range.end.line,
|
||||
character: diag.range.end.character,
|
||||
},
|
||||
},
|
||||
source: diag.source,
|
||||
code:
|
||||
diag.code !== undefined && diag.code !== null
|
||||
? String(diag.code)
|
||||
: undefined,
|
||||
}),
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
uri,
|
||||
diagnostics,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/**
|
||||
* Handler registration result with tracking data
|
||||
*/
|
||||
export type HandlerRegistrationResult = {
|
||||
/** Total number of servers */
|
||||
totalServers: number
|
||||
/** Number of successful registrations */
|
||||
successCount: number
|
||||
/** Registration errors per server */
|
||||
registrationErrors: Array<{ serverName: string; error: string }>
|
||||
/** Runtime failure tracking (shared across all handler invocations) */
|
||||
diagnosticFailures: Map<string, { count: number; lastError: string }>
|
||||
}
|
||||
|
||||
/**
|
||||
* Register LSP notification handlers on all servers
|
||||
*
|
||||
* Sets up handlers to listen for textDocument/publishDiagnostics notifications
|
||||
* from all LSP servers and routes them to Claude's diagnostic system.
|
||||
* Uses public getAllServers() API for clean access to server instances.
|
||||
*
|
||||
* @returns Tracking data for registration status and runtime failures
|
||||
*/
|
||||
export function registerLSPNotificationHandlers(
|
||||
manager: LSPServerManager,
|
||||
): HandlerRegistrationResult {
|
||||
// Register handlers on all configured servers to capture diagnostics from any language
|
||||
const servers = manager.getAllServers()
|
||||
|
||||
// Track partial failures - allow successful server registrations even if some fail
|
||||
const registrationErrors: Array<{ serverName: string; error: string }> = []
|
||||
let successCount = 0
|
||||
|
||||
// Track consecutive failures per server to warn users after 3+ failures
|
||||
const diagnosticFailures: Map<string, { count: number; lastError: string }> =
|
||||
new Map()
|
||||
|
||||
for (const [serverName, serverInstance] of servers.entries()) {
|
||||
try {
|
||||
// Validate server instance has onNotification method
|
||||
if (
|
||||
!serverInstance ||
|
||||
typeof serverInstance.onNotification !== 'function'
|
||||
) {
|
||||
const errorMsg = !serverInstance
|
||||
? 'Server instance is null/undefined'
|
||||
: 'Server instance has no onNotification method'
|
||||
|
||||
registrationErrors.push({ serverName, error: errorMsg })
|
||||
|
||||
const err = new Error(`${errorMsg} for ${serverName}`)
|
||||
logError(err)
|
||||
logForDebugging(
|
||||
`Skipping handler registration for ${serverName}: ${errorMsg}`,
|
||||
)
|
||||
continue // Skip this server but track the failure
|
||||
}
|
||||
|
||||
// Errors are isolated to avoid breaking other servers
|
||||
serverInstance.onNotification(
|
||||
'textDocument/publishDiagnostics',
|
||||
(params: unknown) => {
|
||||
logForDebugging(
|
||||
`[PASSIVE DIAGNOSTICS] Handler invoked for ${serverName}! Params type: ${typeof params}`,
|
||||
)
|
||||
try {
|
||||
// Validate params structure before casting
|
||||
if (
|
||||
!params ||
|
||||
typeof params !== 'object' ||
|
||||
!('uri' in params) ||
|
||||
!('diagnostics' in params)
|
||||
) {
|
||||
const err = new Error(
|
||||
`LSP server ${serverName} sent invalid diagnostic params (missing uri or diagnostics)`,
|
||||
)
|
||||
logError(err)
|
||||
logForDebugging(
|
||||
`Invalid diagnostic params from ${serverName}: ${jsonStringify(params)}`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
const diagnosticParams = params as PublishDiagnosticsParams
|
||||
logForDebugging(
|
||||
`Received diagnostics from ${serverName}: ${diagnosticParams.diagnostics.length} diagnostic(s) for ${diagnosticParams.uri}`,
|
||||
)
|
||||
|
||||
// Convert LSP diagnostics to Claude format (can throw on invalid URIs)
|
||||
const diagnosticFiles =
|
||||
formatDiagnosticsForAttachment(diagnosticParams)
|
||||
|
||||
// Only send notification if there are diagnostics
|
||||
const firstFile = diagnosticFiles[0]
|
||||
if (
|
||||
!firstFile ||
|
||||
diagnosticFiles.length === 0 ||
|
||||
firstFile.diagnostics.length === 0
|
||||
) {
|
||||
logForDebugging(
|
||||
`Skipping empty diagnostics from ${serverName} for ${diagnosticParams.uri}`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Register diagnostics for async delivery via attachment system
|
||||
// Follows same pattern as AsyncHookRegistry for consistent async attachment delivery
|
||||
try {
|
||||
registerPendingLSPDiagnostic({
|
||||
serverName,
|
||||
files: diagnosticFiles,
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
`LSP Diagnostics: Registered ${diagnosticFiles.length} diagnostic file(s) from ${serverName} for async delivery`,
|
||||
)
|
||||
|
||||
// Success - reset failure counter for this server
|
||||
diagnosticFailures.delete(serverName)
|
||||
} catch (error) {
|
||||
const err = toError(error)
|
||||
logError(err)
|
||||
logForDebugging(
|
||||
`Error registering LSP diagnostics from ${serverName}: ` +
|
||||
`URI: ${diagnosticParams.uri}, ` +
|
||||
`Diagnostic count: ${firstFile.diagnostics.length}, ` +
|
||||
`Error: ${err.message}`,
|
||||
)
|
||||
|
||||
// Track consecutive failures and warn after 3+
|
||||
const failures = diagnosticFailures.get(serverName) || {
|
||||
count: 0,
|
||||
lastError: '',
|
||||
}
|
||||
failures.count++
|
||||
failures.lastError = err.message
|
||||
diagnosticFailures.set(serverName, failures)
|
||||
|
||||
if (failures.count >= 3) {
|
||||
logForDebugging(
|
||||
`WARNING: LSP diagnostic handler for ${serverName} has failed ${failures.count} times consecutively. ` +
|
||||
`Last error: ${failures.lastError}. ` +
|
||||
`This may indicate a problem with the LSP server or diagnostic processing. ` +
|
||||
`Check logs for details.`,
|
||||
)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Catch any unexpected errors from the entire handler to prevent breaking the notification loop
|
||||
const err = toError(error)
|
||||
logError(err)
|
||||
logForDebugging(
|
||||
`Unexpected error processing diagnostics from ${serverName}: ${err.message}`,
|
||||
)
|
||||
|
||||
// Track consecutive failures and warn after 3+
|
||||
const failures = diagnosticFailures.get(serverName) || {
|
||||
count: 0,
|
||||
lastError: '',
|
||||
}
|
||||
failures.count++
|
||||
failures.lastError = err.message
|
||||
diagnosticFailures.set(serverName, failures)
|
||||
|
||||
if (failures.count >= 3) {
|
||||
logForDebugging(
|
||||
`WARNING: LSP diagnostic handler for ${serverName} has failed ${failures.count} times consecutively. ` +
|
||||
`Last error: ${failures.lastError}. ` +
|
||||
`This may indicate a problem with the LSP server or diagnostic processing. ` +
|
||||
`Check logs for details.`,
|
||||
)
|
||||
}
|
||||
|
||||
// Don't re-throw - isolate errors to this server only
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logForDebugging(`Registered diagnostics handler for ${serverName}`)
|
||||
successCount++
|
||||
} catch (error) {
|
||||
const err = toError(error)
|
||||
|
||||
registrationErrors.push({
|
||||
serverName,
|
||||
error: err.message,
|
||||
})
|
||||
|
||||
logError(err)
|
||||
logForDebugging(
|
||||
`Failed to register diagnostics handler for ${serverName}: ` +
|
||||
`Error: ${err.message}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Report overall registration status
|
||||
const totalServers = servers.size
|
||||
if (registrationErrors.length > 0) {
|
||||
const failedServers = registrationErrors
|
||||
.map(e => `${e.serverName} (${e.error})`)
|
||||
.join(', ')
|
||||
// Log aggregate failures for tracking
|
||||
logError(
|
||||
new Error(
|
||||
`Failed to register diagnostics for ${registrationErrors.length} LSP server(s): ${failedServers}`,
|
||||
),
|
||||
)
|
||||
logForDebugging(
|
||||
`LSP notification handler registration: ${successCount}/${totalServers} succeeded. ` +
|
||||
`Failed servers: ${failedServers}. ` +
|
||||
`Diagnostics from failed servers will not be delivered.`,
|
||||
)
|
||||
} else {
|
||||
logForDebugging(
|
||||
`LSP notification handlers registered successfully for all ${totalServers} server(s)`,
|
||||
)
|
||||
}
|
||||
|
||||
// Return tracking data for monitoring and testing
|
||||
return {
|
||||
totalServers,
|
||||
successCount,
|
||||
registrationErrors,
|
||||
diagnosticFailures,
|
||||
}
|
||||
}
|
||||
63
src/services/mcp/InProcessTransport.ts
Normal file
63
src/services/mcp/InProcessTransport.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'
|
||||
|
||||
/**
|
||||
* In-process linked transport pair for running an MCP server and client
|
||||
* in the same process without spawning a subprocess.
|
||||
*
|
||||
* `send()` on one side delivers to `onmessage` on the other.
|
||||
* `close()` on either side calls `onclose` on both.
|
||||
*/
|
||||
class InProcessTransport implements Transport {
|
||||
private peer: InProcessTransport | undefined
|
||||
private closed = false
|
||||
|
||||
onclose?: () => void
|
||||
onerror?: (error: Error) => void
|
||||
onmessage?: (message: JSONRPCMessage) => void
|
||||
|
||||
/** @internal */
|
||||
_setPeer(peer: InProcessTransport): void {
|
||||
this.peer = peer
|
||||
}
|
||||
|
||||
async start(): Promise<void> {}
|
||||
|
||||
async send(message: JSONRPCMessage): Promise<void> {
|
||||
if (this.closed) {
|
||||
throw new Error('Transport is closed')
|
||||
}
|
||||
// Deliver to the other side asynchronously to avoid stack depth issues
|
||||
// with synchronous request/response cycles
|
||||
queueMicrotask(() => {
|
||||
this.peer?.onmessage?.(message)
|
||||
})
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
if (this.closed) {
|
||||
return
|
||||
}
|
||||
this.closed = true
|
||||
this.onclose?.()
|
||||
// Close the peer if it hasn't already closed
|
||||
if (this.peer && !this.peer.closed) {
|
||||
this.peer.closed = true
|
||||
this.peer.onclose?.()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a pair of linked transports for in-process MCP communication.
|
||||
* Messages sent on one transport are delivered to the other's `onmessage`.
|
||||
*
|
||||
* @returns [clientTransport, serverTransport]
|
||||
*/
|
||||
export function createLinkedTransportPair(): [Transport, Transport] {
|
||||
const a = new InProcessTransport()
|
||||
const b = new InProcessTransport()
|
||||
a._setPeer(b)
|
||||
b._setPeer(a)
|
||||
return [a, b]
|
||||
}
|
||||
73
src/services/mcp/MCPConnectionManager.tsx
Normal file
73
src/services/mcp/MCPConnectionManager.tsx
Normal file
File diff suppressed because one or more lines are too long
136
src/services/mcp/SdkControlTransport.ts
Normal file
136
src/services/mcp/SdkControlTransport.ts
Normal file
@@ -0,0 +1,136 @@
|
||||
/**
|
||||
* SDK MCP Transport Bridge
|
||||
*
|
||||
* This file implements a transport bridge that allows MCP servers running in the SDK process
|
||||
* to communicate with the Claude Code CLI process through control messages.
|
||||
*
|
||||
* ## Architecture Overview
|
||||
*
|
||||
* Unlike regular MCP servers that run as separate processes, SDK MCP servers run in-process
|
||||
* within the SDK. This requires a special transport mechanism to bridge communication between:
|
||||
* - The CLI process (where the MCP client runs)
|
||||
* - The SDK process (where the SDK MCP server runs)
|
||||
*
|
||||
* ## Message Flow
|
||||
*
|
||||
* ### CLI → SDK (via SdkControlClientTransport)
|
||||
* 1. CLI's MCP Client calls a tool → sends JSONRPC request to SdkControlClientTransport
|
||||
* 2. Transport wraps the message in a control request with server_name and request_id
|
||||
* 3. Control request is sent via stdout to the SDK process
|
||||
* 4. SDK's StructuredIO receives the control response and routes it back to the transport
|
||||
* 5. Transport unwraps the response and returns it to the MCP Client
|
||||
*
|
||||
* ### SDK → CLI (via SdkControlServerTransport)
|
||||
* 1. Query receives control request with MCP message and calls transport.onmessage
|
||||
* 2. MCP server processes the message and calls transport.send() with response
|
||||
* 3. Transport calls sendMcpMessage callback with the response
|
||||
* 4. Query's callback resolves the pending promise with the response
|
||||
* 5. Query returns the response to complete the control request
|
||||
*
|
||||
* ## Key Design Points
|
||||
*
|
||||
* - SdkControlClientTransport: StructuredIO tracks pending requests
|
||||
* - SdkControlServerTransport: Query tracks pending requests
|
||||
* - The control request wrapper includes server_name to route to the correct SDK server
|
||||
* - The system supports multiple SDK MCP servers running simultaneously
|
||||
* - Message IDs are preserved through the entire flow for proper correlation
|
||||
*/
|
||||
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'
|
||||
|
||||
/**
|
||||
* Callback function to send an MCP message and get the response
|
||||
*/
|
||||
export type SendMcpMessageCallback = (
|
||||
serverName: string,
|
||||
message: JSONRPCMessage,
|
||||
) => Promise<JSONRPCMessage>
|
||||
|
||||
/**
|
||||
* CLI-side transport for SDK MCP servers.
|
||||
*
|
||||
* This transport is used in the CLI process to bridge communication between:
|
||||
* - The CLI's MCP Client (which wants to call tools on SDK MCP servers)
|
||||
* - The SDK process (where the actual MCP server runs)
|
||||
*
|
||||
* It converts MCP protocol messages into control requests that can be sent
|
||||
* through stdout/stdin to the SDK process.
|
||||
*/
|
||||
export class SdkControlClientTransport implements Transport {
|
||||
private isClosed = false
|
||||
|
||||
onclose?: () => void
|
||||
onerror?: (error: Error) => void
|
||||
onmessage?: (message: JSONRPCMessage) => void
|
||||
|
||||
constructor(
|
||||
private serverName: string,
|
||||
private sendMcpMessage: SendMcpMessageCallback,
|
||||
) {}
|
||||
|
||||
async start(): Promise<void> {}
|
||||
|
||||
async send(message: JSONRPCMessage): Promise<void> {
|
||||
if (this.isClosed) {
|
||||
throw new Error('Transport is closed')
|
||||
}
|
||||
|
||||
// Send the message and wait for the response
|
||||
const response = await this.sendMcpMessage(this.serverName, message)
|
||||
|
||||
// Pass the response back to the MCP client
|
||||
if (this.onmessage) {
|
||||
this.onmessage(response)
|
||||
}
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
if (this.isClosed) {
|
||||
return
|
||||
}
|
||||
this.isClosed = true
|
||||
this.onclose?.()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* SDK-side transport for SDK MCP servers.
|
||||
*
|
||||
* This transport is used in the SDK process to bridge communication between:
|
||||
* - Control requests coming from the CLI (via stdin)
|
||||
* - The actual MCP server running in the SDK process
|
||||
*
|
||||
* It acts as a simple pass-through that forwards messages to the MCP server
|
||||
* and sends responses back via a callback.
|
||||
*
|
||||
* Note: Query handles all request/response correlation and async flow.
|
||||
*/
|
||||
export class SdkControlServerTransport implements Transport {
|
||||
private isClosed = false
|
||||
|
||||
constructor(private sendMcpMessage: (message: JSONRPCMessage) => void) {}
|
||||
|
||||
onclose?: () => void
|
||||
onerror?: (error: Error) => void
|
||||
onmessage?: (message: JSONRPCMessage) => void
|
||||
|
||||
async start(): Promise<void> {}
|
||||
|
||||
async send(message: JSONRPCMessage): Promise<void> {
|
||||
if (this.isClosed) {
|
||||
throw new Error('Transport is closed')
|
||||
}
|
||||
|
||||
// Simply pass the response back through the callback
|
||||
this.sendMcpMessage(message)
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
if (this.isClosed) {
|
||||
return
|
||||
}
|
||||
this.isClosed = true
|
||||
this.onclose?.()
|
||||
}
|
||||
}
|
||||
2465
src/services/mcp/auth.ts
Normal file
2465
src/services/mcp/auth.ts
Normal file
File diff suppressed because it is too large
Load Diff
76
src/services/mcp/channelAllowlist.ts
Normal file
76
src/services/mcp/channelAllowlist.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
/**
|
||||
* Approved channel plugins allowlist. --channels plugin:name@marketplace
|
||||
* entries only register if {marketplace, plugin} is on this list. server:
|
||||
* entries always fail (schema is plugin-only). The
|
||||
* --dangerously-load-development-channels flag bypasses for both kinds.
|
||||
* Lives in GrowthBook so it can be updated without a release.
|
||||
*
|
||||
* Plugin-level granularity: if a plugin is approved, all its channel
|
||||
* servers are. Per-server gating was overengineering — a plugin that
|
||||
* sprouts a malicious second server is already compromised, and per-server
|
||||
* entries would break on harmless plugin refactors.
|
||||
*
|
||||
* The allowlist check is a pure {marketplace, plugin} comparison against
|
||||
* the user's typed tag. The gate's separate 'marketplace' step verifies
|
||||
* the tag matches what's actually installed before this check runs.
|
||||
*/
|
||||
|
||||
import { z } from 'zod/v4'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
import { parsePluginIdentifier } from '../../utils/plugins/pluginIdentifier.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
|
||||
export type ChannelAllowlistEntry = {
|
||||
marketplace: string
|
||||
plugin: string
|
||||
}
|
||||
|
||||
const ChannelAllowlistSchema = lazySchema(() =>
|
||||
z.array(
|
||||
z.object({
|
||||
marketplace: z.string(),
|
||||
plugin: z.string(),
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
export function getChannelAllowlist(): ChannelAllowlistEntry[] {
|
||||
const raw = getFeatureValue_CACHED_MAY_BE_STALE<unknown>(
|
||||
'tengu_harbor_ledger',
|
||||
[],
|
||||
)
|
||||
const parsed = ChannelAllowlistSchema().safeParse(raw)
|
||||
return parsed.success ? parsed.data : []
|
||||
}
|
||||
|
||||
/**
|
||||
* Overall channels on/off. Checked before any per-server gating —
|
||||
* when false, --channels is a no-op and no handlers register.
|
||||
* Default false; GrowthBook 5-min refresh.
|
||||
*/
|
||||
export function isChannelsEnabled(): boolean {
|
||||
return getFeatureValue_CACHED_MAY_BE_STALE('tengu_harbor', false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Pure allowlist check keyed off the connection's pluginSource — for UI
|
||||
* pre-filtering so the IDE only shows "Enable channel?" for servers that will
|
||||
* actually pass the gate. Not a security boundary: channel_enable still runs
|
||||
* the full gate. Matches the allowlist comparison inside gateChannelServer()
|
||||
* but standalone (no session/marketplace coupling — those are tautologies
|
||||
* when the entry is derived from pluginSource).
|
||||
*
|
||||
* Returns false for undefined pluginSource (non-plugin server — can never
|
||||
* match the {marketplace, plugin}-keyed ledger) and for @-less sources
|
||||
* (builtin/inline — same reason).
|
||||
*/
|
||||
export function isChannelAllowlisted(
|
||||
pluginSource: string | undefined,
|
||||
): boolean {
|
||||
if (!pluginSource) return false
|
||||
const { name, marketplace } = parsePluginIdentifier(pluginSource)
|
||||
if (!marketplace) return false
|
||||
return getChannelAllowlist().some(
|
||||
e => e.plugin === name && e.marketplace === marketplace,
|
||||
)
|
||||
}
|
||||
316
src/services/mcp/channelNotification.ts
Normal file
316
src/services/mcp/channelNotification.ts
Normal file
@@ -0,0 +1,316 @@
|
||||
/**
|
||||
* Channel notifications — lets an MCP server push user messages into the
|
||||
* conversation. A "channel" (Discord, Slack, SMS, etc.) is just an MCP server
|
||||
* that:
|
||||
* - exposes tools for outbound messages (e.g. `send_message`) — standard MCP
|
||||
* - sends `notifications/claude/channel` notifications for inbound — this file
|
||||
*
|
||||
* The notification handler wraps the content in a <channel> tag and
|
||||
* enqueues it. SleepTool polls hasCommandsInQueue() and wakes within 1s.
|
||||
* The model sees where the message came from and decides which tool to reply
|
||||
* with (the channel's MCP tool, SendUserMessage, or both).
|
||||
*
|
||||
* feature('KAIROS') || feature('KAIROS_CHANNELS'). Runtime gate tengu_harbor.
|
||||
* Requires claude.ai OAuth auth — API key users are blocked until
|
||||
* console gets a channelsEnabled admin surface. Teams/Enterprise orgs
|
||||
* must explicitly opt in via channelsEnabled: true in managed settings.
|
||||
*/
|
||||
|
||||
import type { ServerCapabilities } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { z } from 'zod/v4'
|
||||
import { type ChannelEntry, getAllowedChannels } from '../../bootstrap/state.js'
|
||||
import { CHANNEL_TAG } from '../../constants/xml.js'
|
||||
import {
|
||||
getClaudeAIOAuthTokens,
|
||||
getSubscriptionType,
|
||||
} from '../../utils/auth.js'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
import { parsePluginIdentifier } from '../../utils/plugins/pluginIdentifier.js'
|
||||
import { getSettingsForSource } from '../../utils/settings/settings.js'
|
||||
import { escapeXmlAttr } from '../../utils/xml.js'
|
||||
import {
|
||||
type ChannelAllowlistEntry,
|
||||
getChannelAllowlist,
|
||||
isChannelsEnabled,
|
||||
} from './channelAllowlist.js'
|
||||
|
||||
export const ChannelMessageNotificationSchema = lazySchema(() =>
|
||||
z.object({
|
||||
method: z.literal('notifications/claude/channel'),
|
||||
params: z.object({
|
||||
content: z.string(),
|
||||
// Opaque passthrough — thread_id, user, whatever the channel wants the
|
||||
// model to see. Rendered as attributes on the <channel> tag.
|
||||
meta: z.record(z.string(), z.string()).optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
|
||||
/**
|
||||
* Structured permission reply from a channel server. Servers that support
|
||||
* this declare `capabilities.experimental['claude/channel/permission']` and
|
||||
* emit this event INSTEAD of relaying "yes tbxkq" as text via
|
||||
* notifications/claude/channel. Explicit opt-in per server — a channel that
|
||||
* just wants to relay text never becomes a permission surface by accident.
|
||||
*
|
||||
* The server parses the user's reply (spec: /^\s*(y|yes|n|no)\s+([a-km-z]{5})\s*$/i)
|
||||
* and emits {request_id, behavior}. CC matches request_id against its
|
||||
* pending map. Unlike the regex-intercept approach, text in the general
|
||||
* channel can never accidentally match — approval requires the server
|
||||
* to deliberately emit this specific event.
|
||||
*/
|
||||
export const CHANNEL_PERMISSION_METHOD =
|
||||
'notifications/claude/channel/permission'
|
||||
export const ChannelPermissionNotificationSchema = lazySchema(() =>
|
||||
z.object({
|
||||
method: z.literal(CHANNEL_PERMISSION_METHOD),
|
||||
params: z.object({
|
||||
request_id: z.string(),
|
||||
behavior: z.enum(['allow', 'deny']),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
|
||||
/**
|
||||
* Outbound: CC → server. Fired from interactiveHandler.ts when a
|
||||
* permission dialog opens and the server has declared the permission
|
||||
* capability. Server formats the message for its platform (Telegram
|
||||
* markdown, iMessage rich text, Discord embed) and sends it to the
|
||||
* human. When the human replies "yes tbxkq", the server parses that
|
||||
* against PERMISSION_REPLY_RE and emits the inbound schema above.
|
||||
*
|
||||
* Not a zod schema — CC SENDS this, doesn't validate it. A type here
|
||||
* keeps both halves of the protocol documented side by side.
|
||||
*/
|
||||
export const CHANNEL_PERMISSION_REQUEST_METHOD =
|
||||
'notifications/claude/channel/permission_request'
|
||||
export type ChannelPermissionRequestParams = {
|
||||
request_id: string
|
||||
tool_name: string
|
||||
description: string
|
||||
/** JSON-stringified tool input, truncated to 200 chars with …. Full
|
||||
* input is in the local terminal dialog; this is a phone-sized
|
||||
* preview. Server decides whether/how to show it. */
|
||||
input_preview: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Meta keys become XML attribute NAMES — a crafted key like
|
||||
* `x="" injected="y` would break out of the attribute structure. Only
|
||||
* accept keys that look like plain identifiers. This is stricter than
|
||||
* the XML spec (which allows `:`, `.`, `-`) but channel servers only
|
||||
* send `chat_id`, `user`, `thread_ts`, `message_id` in practice.
|
||||
*/
|
||||
const SAFE_META_KEY = /^[a-zA-Z_][a-zA-Z0-9_]*$/
|
||||
|
||||
export function wrapChannelMessage(
|
||||
serverName: string,
|
||||
content: string,
|
||||
meta?: Record<string, string>,
|
||||
): string {
|
||||
const attrs = Object.entries(meta ?? {})
|
||||
.filter(([k]) => SAFE_META_KEY.test(k))
|
||||
.map(([k, v]) => ` ${k}="${escapeXmlAttr(v)}"`)
|
||||
.join('')
|
||||
return `<${CHANNEL_TAG} source="${escapeXmlAttr(serverName)}"${attrs}>\n${content}\n</${CHANNEL_TAG}>`
|
||||
}
|
||||
|
||||
/**
|
||||
* Effective allowlist for the current session. Team/enterprise orgs can set
|
||||
* allowedChannelPlugins in managed settings — when set, it REPLACES the
|
||||
* GrowthBook ledger (admin owns the trust decision). Undefined falls back
|
||||
* to the ledger. Unmanaged users always get the ledger.
|
||||
*
|
||||
* Callers already read sub/policy for the policy gate — pass them in to
|
||||
* avoid double-reading getSettingsForSource (uncached).
|
||||
*/
|
||||
export function getEffectiveChannelAllowlist(
|
||||
sub: ReturnType<typeof getSubscriptionType>,
|
||||
orgList: ChannelAllowlistEntry[] | undefined,
|
||||
): {
|
||||
entries: ChannelAllowlistEntry[]
|
||||
source: 'org' | 'ledger'
|
||||
} {
|
||||
if ((sub === 'team' || sub === 'enterprise') && orgList) {
|
||||
return { entries: orgList, source: 'org' }
|
||||
}
|
||||
return { entries: getChannelAllowlist(), source: 'ledger' }
|
||||
}
|
||||
|
||||
export type ChannelGateResult =
|
||||
| { action: 'register' }
|
||||
| {
|
||||
action: 'skip'
|
||||
kind:
|
||||
| 'capability'
|
||||
| 'disabled'
|
||||
| 'auth'
|
||||
| 'policy'
|
||||
| 'session'
|
||||
| 'marketplace'
|
||||
| 'allowlist'
|
||||
reason: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Match a connected MCP server against the user's parsed --channels entries.
|
||||
* server-kind is exact match on bare name; plugin-kind matches on the second
|
||||
* segment of plugin:X:Y. Returns the matching entry so callers can read its
|
||||
* kind — that's the user's trust declaration, not inferred from runtime shape.
|
||||
*/
|
||||
export function findChannelEntry(
|
||||
serverName: string,
|
||||
channels: readonly ChannelEntry[],
|
||||
): ChannelEntry | undefined {
|
||||
// split unconditionally — for a bare name like 'slack', parts is ['slack']
|
||||
// and the plugin-kind branch correctly never matches (parts[0] !== 'plugin').
|
||||
const parts = serverName.split(':')
|
||||
return channels.find(c =>
|
||||
c.kind === 'server'
|
||||
? serverName === c.name
|
||||
: parts[0] === 'plugin' && parts[1] === c.name,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Gate an MCP server's channel-notification path. Caller checks
|
||||
* feature('KAIROS') || feature('KAIROS_CHANNELS') first (build-time
|
||||
* elimination). Gate order: capability → runtime gate (tengu_harbor) →
|
||||
* auth (OAuth only) → org policy → session --channels → allowlist.
|
||||
* API key users are blocked at the auth layer — channels requires
|
||||
* claude.ai auth; console orgs have no admin opt-in surface yet.
|
||||
*
|
||||
* skip Not a channel server, or managed org hasn't opted in, or
|
||||
* not in session --channels. Connection stays up; handler
|
||||
* not registered.
|
||||
* register Subscribe to notifications/claude/channel.
|
||||
*
|
||||
* Which servers can connect at all is governed by allowedMcpServers —
|
||||
* this gate only decides whether the notification handler registers.
|
||||
*/
|
||||
export function gateChannelServer(
|
||||
serverName: string,
|
||||
capabilities: ServerCapabilities | undefined,
|
||||
pluginSource: string | undefined,
|
||||
): ChannelGateResult {
|
||||
// Channel servers declare `experimental['claude/channel']: {}` (MCP's
|
||||
// presence-signal idiom — same as `tools: {}`). Truthy covers `{}` and
|
||||
// `true`; absent/undefined/explicit-`false` all fail. Key matches the
|
||||
// notification method namespace (notifications/claude/channel).
|
||||
if (!capabilities?.experimental?.['claude/channel']) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'capability',
|
||||
reason: 'server did not declare claude/channel capability',
|
||||
}
|
||||
}
|
||||
|
||||
// Overall runtime gate. After capability so normal MCP servers never hit
|
||||
// this path. Before auth/policy so the killswitch works regardless of
|
||||
// session state.
|
||||
if (!isChannelsEnabled()) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'disabled',
|
||||
reason: 'channels feature is not currently available',
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth-only. API key users (console) are blocked — there's no
|
||||
// channelsEnabled admin surface in console yet, so the policy opt-in
|
||||
// flow doesn't exist for them. Drop this when console parity lands.
|
||||
if (!getClaudeAIOAuthTokens()?.accessToken) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'auth',
|
||||
reason: 'channels requires claude.ai authentication (run /login)',
|
||||
}
|
||||
}
|
||||
|
||||
// Teams/Enterprise opt-in. Managed orgs must explicitly enable channels.
|
||||
// Default OFF — absent or false blocks. Keyed off subscription tier, not
|
||||
// "policy settings exist" — a team org with zero configured policy keys
|
||||
// (remote endpoint returns 404) is still a managed org and must not fall
|
||||
// through to the unmanaged path.
|
||||
const sub = getSubscriptionType()
|
||||
const managed = sub === 'team' || sub === 'enterprise'
|
||||
const policy = managed ? getSettingsForSource('policySettings') : undefined
|
||||
if (managed && policy?.channelsEnabled !== true) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'policy',
|
||||
reason:
|
||||
'channels not enabled by org policy (set channelsEnabled: true in managed settings)',
|
||||
}
|
||||
}
|
||||
|
||||
// User-level session opt-in. A server must be explicitly listed in
|
||||
// --channels to push inbound this session — protects against a trusted
|
||||
// server surprise-adding the capability.
|
||||
const entry = findChannelEntry(serverName, getAllowedChannels())
|
||||
if (!entry) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'session',
|
||||
reason: `server ${serverName} not in --channels list for this session`,
|
||||
}
|
||||
}
|
||||
|
||||
if (entry.kind === 'plugin') {
|
||||
// Marketplace verification: the tag is intent (plugin:slack@anthropic),
|
||||
// the runtime name is just plugin:slack:X — could be slack@anthropic or
|
||||
// slack@evil depending on what's installed. Verify they match before
|
||||
// trusting the tag for the allowlist check below. Source is stashed on
|
||||
// the config at addPluginScopeToServers — undefined (non-plugin server,
|
||||
// shouldn't happen for plugin-kind entry) or @-less (builtin/inline)
|
||||
// both fail the comparison.
|
||||
const actual = pluginSource
|
||||
? parsePluginIdentifier(pluginSource).marketplace
|
||||
: undefined
|
||||
if (actual !== entry.marketplace) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'marketplace',
|
||||
reason: `you asked for plugin:${entry.name}@${entry.marketplace} but the installed ${entry.name} plugin is from ${actual ?? 'an unknown source'}`,
|
||||
}
|
||||
}
|
||||
|
||||
// Approved-plugin allowlist. Marketplace gate already verified
|
||||
// tag == reality, so this is a pure entry check. entry.dev (per-entry,
|
||||
// not the session-wide bit) bypasses — so accepting the dev dialog for
|
||||
// one entry doesn't leak allowlist-bypass to --channels entries.
|
||||
if (!entry.dev) {
|
||||
const { entries, source } = getEffectiveChannelAllowlist(
|
||||
sub,
|
||||
policy?.allowedChannelPlugins,
|
||||
)
|
||||
if (
|
||||
!entries.some(
|
||||
e => e.plugin === entry.name && e.marketplace === entry.marketplace,
|
||||
)
|
||||
) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'allowlist',
|
||||
reason:
|
||||
source === 'org'
|
||||
? `plugin ${entry.name}@${entry.marketplace} is not on your org's approved channels list (set allowedChannelPlugins in managed settings)`
|
||||
: `plugin ${entry.name}@${entry.marketplace} is not on the approved channels allowlist (use --dangerously-load-development-channels for local dev)`,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// server-kind: allowlist schema is {marketplace, plugin} — a server entry
|
||||
// can never match. Without this, --channels server:plugin:foo:bar would
|
||||
// match a plugin's runtime name and register with no allowlist check.
|
||||
if (!entry.dev) {
|
||||
return {
|
||||
action: 'skip',
|
||||
kind: 'allowlist',
|
||||
reason: `server ${entry.name} is not on the approved channels allowlist (use --dangerously-load-development-channels for local dev)`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { action: 'register' }
|
||||
}
|
||||
240
src/services/mcp/channelPermissions.ts
Normal file
240
src/services/mcp/channelPermissions.ts
Normal file
@@ -0,0 +1,240 @@
|
||||
/**
|
||||
* Permission prompts over channels (Telegram, iMessage, Discord).
|
||||
*
|
||||
* Mirrors `BridgePermissionCallbacks` — when CC hits a permission dialog,
|
||||
* it ALSO sends the prompt via active channels and races the reply against
|
||||
* local UI / bridge / hooks / classifier. First resolver wins via claim().
|
||||
*
|
||||
* Inbound is a structured event: the server parses the user's "yes tbxkq"
|
||||
* reply and emits notifications/claude/channel/permission with
|
||||
* {request_id, behavior}. CC never sees the reply as text — approval
|
||||
* requires the server to deliberately emit that specific event, not just
|
||||
* relay content. Servers opt in by declaring
|
||||
* capabilities.experimental['claude/channel/permission'].
|
||||
*
|
||||
* Kenneth's "would this let Claude self-approve?": the approving party is
|
||||
* the human via the channel, not Claude. But the trust boundary isn't the
|
||||
* terminal — it's the allowlist (tengu_harbor_ledger). A compromised
|
||||
* channel server CAN fabricate "yes <id>" without the human seeing the
|
||||
* prompt. Accepted risk: a compromised channel already has unlimited
|
||||
* conversation-injection turns (social-engineer over time, wait for
|
||||
* acceptEdits, etc.); inject-then-self-approve is faster, not more
|
||||
* capable. The dialog slows a compromised channel; it doesn't stop one.
|
||||
* See PR discussion 2956440848.
|
||||
*/
|
||||
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../analytics/growthbook.js'
|
||||
|
||||
/**
|
||||
* GrowthBook runtime gate — separate from the channels gate (tengu_harbor)
|
||||
* so channels can ship without permission-relay riding along (Kenneth: "no
|
||||
* bake time if it goes out tomorrow"). Default false; flip without a release.
|
||||
* Checked once at useManageMCPConnections mount — mid-session flag changes
|
||||
* don't apply until restart.
|
||||
*/
|
||||
export function isChannelPermissionRelayEnabled(): boolean {
|
||||
return getFeatureValue_CACHED_MAY_BE_STALE('tengu_harbor_permissions', false)
|
||||
}
|
||||
|
||||
export type ChannelPermissionResponse = {
|
||||
behavior: 'allow' | 'deny'
|
||||
/** Which channel server the reply came from (e.g., "plugin:telegram:tg"). */
|
||||
fromServer: string
|
||||
}
|
||||
|
||||
export type ChannelPermissionCallbacks = {
|
||||
/** Register a resolver for a request ID. Returns unsubscribe. */
|
||||
onResponse(
|
||||
requestId: string,
|
||||
handler: (response: ChannelPermissionResponse) => void,
|
||||
): () => void
|
||||
/** Resolve a pending request from a structured channel event
|
||||
* (notifications/claude/channel/permission). Returns true if the ID
|
||||
* was pending — the server parsed the user's reply and emitted
|
||||
* {request_id, behavior}; we just match against the map. */
|
||||
resolve(
|
||||
requestId: string,
|
||||
behavior: 'allow' | 'deny',
|
||||
fromServer: string,
|
||||
): boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Reply format spec for channel servers to implement:
|
||||
* /^\s*(y|yes|n|no)\s+([a-km-z]{5})\s*$/i
|
||||
*
|
||||
* 5 lowercase letters, no 'l' (looks like 1/I). Case-insensitive (phone
|
||||
* autocorrect). No bare yes/no (conversational). No prefix/suffix chatter.
|
||||
*
|
||||
* CC generates the ID and sends the prompt. The SERVER parses the user's
|
||||
* reply and emits notifications/claude/channel/permission with {request_id,
|
||||
* behavior} — CC doesn't regex-match text anymore. Exported so plugins can
|
||||
* import the exact regex rather than hand-copying it.
|
||||
*/
|
||||
export const PERMISSION_REPLY_RE = /^\s*(y|yes|n|no)\s+([a-km-z]{5})\s*$/i
|
||||
|
||||
// 25-letter alphabet: a-z minus 'l' (looks like 1/I). 25^5 ≈ 9.8M space.
|
||||
const ID_ALPHABET = 'abcdefghijkmnopqrstuvwxyz'
|
||||
|
||||
// Substring blocklist — 5 random letters can spell things (Kenneth, in the
|
||||
// launch thread: "this is why i bias to numbers, hard to have anything worse
|
||||
// than 80085"). Non-exhaustive, covers the send-to-your-boss-by-accident
|
||||
// tier. If a generated ID contains any of these, re-hash with a salt.
|
||||
// prettier-ignore
|
||||
const ID_AVOID_SUBSTRINGS = [
|
||||
'fuck',
|
||||
'shit',
|
||||
'cunt',
|
||||
'cock',
|
||||
'dick',
|
||||
'twat',
|
||||
'piss',
|
||||
'crap',
|
||||
'bitch',
|
||||
'whore',
|
||||
'ass',
|
||||
'tit',
|
||||
'cum',
|
||||
'fag',
|
||||
'dyke',
|
||||
'nig',
|
||||
'kike',
|
||||
'rape',
|
||||
'nazi',
|
||||
'damn',
|
||||
'poo',
|
||||
'pee',
|
||||
'wank',
|
||||
'anus',
|
||||
]
|
||||
|
||||
function hashToId(input: string): string {
|
||||
// FNV-1a → uint32, then base-25 encode. Not crypto, just a stable
|
||||
// short letters-only ID. 32 bits / log2(25) ≈ 6.9 letters of entropy;
|
||||
// taking 5 wastes a little, plenty for this.
|
||||
let h = 0x811c9dc5
|
||||
for (let i = 0; i < input.length; i++) {
|
||||
h ^= input.charCodeAt(i)
|
||||
h = Math.imul(h, 0x01000193)
|
||||
}
|
||||
h = h >>> 0
|
||||
let s = ''
|
||||
for (let i = 0; i < 5; i++) {
|
||||
s += ID_ALPHABET[h % 25]
|
||||
h = Math.floor(h / 25)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
/**
|
||||
* Short ID from a toolUseID. 5 letters from a 25-char alphabet (a-z minus
|
||||
* 'l' — looks like 1/I in many fonts). 25^5 ≈ 9.8M space, birthday
|
||||
* collision at 50% needs ~3K simultaneous pending prompts, absurd for a
|
||||
* single interactive session. Letters-only so phone users don't switch
|
||||
* keyboard modes (hex alternates a-f/0-9 → mode toggles). Re-hashes with
|
||||
* a salt suffix if the result contains a blocklisted substring — 5 random
|
||||
* letters can spell things you don't want in a text message to your phone.
|
||||
* toolUseIDs are `toolu_` + base64-ish; we hash rather than slice.
|
||||
*/
|
||||
export function shortRequestId(toolUseID: string): string {
|
||||
// 7 length-3 × 3 positions × 25² + 15 length-4 × 2 × 25 + 2 length-5
|
||||
// ≈ 13,877 blocked IDs out of 9.8M — roughly 1 in 700 hits the blocklist.
|
||||
// Cap at 10 retries; (1/700)^10 is negligible.
|
||||
let candidate = hashToId(toolUseID)
|
||||
for (let salt = 0; salt < 10; salt++) {
|
||||
if (!ID_AVOID_SUBSTRINGS.some(bad => candidate.includes(bad))) {
|
||||
return candidate
|
||||
}
|
||||
candidate = hashToId(`${toolUseID}:${salt}`)
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate tool input to a phone-sized JSON preview. 200 chars is
|
||||
* roughly 3 lines on a narrow phone screen. Full input is in the local
|
||||
* terminal dialog; the channel gets a summary so Write(5KB-file) doesn't
|
||||
* flood your texts. Server decides whether/how to show it.
|
||||
*/
|
||||
export function truncateForPreview(input: unknown): string {
|
||||
try {
|
||||
const s = jsonStringify(input)
|
||||
return s.length > 200 ? s.slice(0, 200) + '…' : s
|
||||
} catch {
|
||||
return '(unserializable)'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter MCP clients down to those that can relay permission prompts.
|
||||
* Three conditions, ALL required: connected + in the session's --channels
|
||||
* allowlist + declares BOTH capabilities. The second capability is the
|
||||
* server's explicit opt-in — a relay-only channel never becomes a
|
||||
* permission surface by accident (Kenneth's "users may be unpleasantly
|
||||
* surprised"). Centralized here so a future fourth condition lands once.
|
||||
*/
|
||||
export function filterPermissionRelayClients<
|
||||
T extends {
|
||||
type: string
|
||||
name: string
|
||||
capabilities?: { experimental?: Record<string, unknown> }
|
||||
},
|
||||
>(
|
||||
clients: readonly T[],
|
||||
isInAllowlist: (name: string) => boolean,
|
||||
): (T & { type: 'connected' })[] {
|
||||
return clients.filter(
|
||||
(c): c is T & { type: 'connected' } =>
|
||||
c.type === 'connected' &&
|
||||
isInAllowlist(c.name) &&
|
||||
c.capabilities?.experimental?.['claude/channel'] !== undefined &&
|
||||
c.capabilities?.experimental?.['claude/channel/permission'] !== undefined,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory for the callbacks object. The pending Map is closed over — NOT
|
||||
* module-level (per src/CLAUDE.md), NOT in AppState (functions-in-state
|
||||
* causes issues with equality/serialization). Same lifetime pattern as
|
||||
* `replBridgePermissionCallbacks`: constructed once per session inside
|
||||
* a React hook, stable reference stored in AppState.
|
||||
*
|
||||
* resolve() is called from the dedicated notification handler
|
||||
* (notifications/claude/channel/permission) with the structured payload.
|
||||
* The server already parsed "yes tbxkq" → {request_id, behavior}; we just
|
||||
* match against the pending map. No regex on CC's side — text in the
|
||||
* general channel can't accidentally approve anything.
|
||||
*/
|
||||
export function createChannelPermissionCallbacks(): ChannelPermissionCallbacks {
|
||||
const pending = new Map<
|
||||
string,
|
||||
(response: ChannelPermissionResponse) => void
|
||||
>()
|
||||
|
||||
return {
|
||||
onResponse(requestId, handler) {
|
||||
// Lowercase here too — resolve() already does; asymmetry means a
|
||||
// future caller passing a mixed-case ID would silently never match.
|
||||
// shortRequestId always emits lowercase so this is a noop today,
|
||||
// but the symmetry makes the contract explicit.
|
||||
const key = requestId.toLowerCase()
|
||||
pending.set(key, handler)
|
||||
return () => {
|
||||
pending.delete(key)
|
||||
}
|
||||
},
|
||||
|
||||
resolve(requestId, behavior, fromServer) {
|
||||
const key = requestId.toLowerCase()
|
||||
const resolver = pending.get(key)
|
||||
if (!resolver) return false
|
||||
// Delete BEFORE calling — if resolver throws or re-enters, the
|
||||
// entry is already gone. Also handles duplicate events (second
|
||||
// emission falls through — server bug or network dup, ignore).
|
||||
pending.delete(key)
|
||||
resolver({ behavior, fromServer })
|
||||
return true
|
||||
},
|
||||
}
|
||||
}
|
||||
164
src/services/mcp/claudeai.ts
Normal file
164
src/services/mcp/claudeai.ts
Normal file
@@ -0,0 +1,164 @@
|
||||
import axios from 'axios'
|
||||
import memoize from 'lodash-es/memoize.js'
|
||||
import { getOauthConfig } from 'src/constants/oauth.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from 'src/services/analytics/index.js'
|
||||
import { getClaudeAIOAuthTokens } from 'src/utils/auth.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from 'src/utils/config.js'
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
import { isEnvDefinedFalsy } from 'src/utils/envUtils.js'
|
||||
import { clearMcpAuthCache } from './client.js'
|
||||
import { normalizeNameForMCP } from './normalization.js'
|
||||
import type { ScopedMcpServerConfig } from './types.js'
|
||||
|
||||
type ClaudeAIMcpServer = {
|
||||
type: 'mcp_server'
|
||||
id: string
|
||||
display_name: string
|
||||
url: string
|
||||
created_at: string
|
||||
}
|
||||
|
||||
type ClaudeAIMcpServersResponse = {
|
||||
data: ClaudeAIMcpServer[]
|
||||
has_more: boolean
|
||||
next_page: string | null
|
||||
}
|
||||
|
||||
const FETCH_TIMEOUT_MS = 5000
|
||||
const MCP_SERVERS_BETA_HEADER = 'mcp-servers-2025-12-04'
|
||||
|
||||
/**
|
||||
* Fetches MCP server configurations from Claude.ai org configs.
|
||||
* These servers are managed by the organization via Claude.ai.
|
||||
*
|
||||
* Results are memoized for the session lifetime (fetch once per CLI session).
|
||||
*/
|
||||
export const fetchClaudeAIMcpConfigsIfEligible = memoize(
|
||||
async (): Promise<Record<string, ScopedMcpServerConfig>> => {
|
||||
try {
|
||||
if (isEnvDefinedFalsy(process.env.ENABLE_CLAUDEAI_MCP_SERVERS)) {
|
||||
logForDebugging('[claudeai-mcp] Disabled via env var')
|
||||
logEvent('tengu_claudeai_mcp_eligibility', {
|
||||
state:
|
||||
'disabled_env_var' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {}
|
||||
}
|
||||
|
||||
const tokens = getClaudeAIOAuthTokens()
|
||||
if (!tokens?.accessToken) {
|
||||
logForDebugging('[claudeai-mcp] No access token')
|
||||
logEvent('tengu_claudeai_mcp_eligibility', {
|
||||
state:
|
||||
'no_oauth_token' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {}
|
||||
}
|
||||
|
||||
// Check for user:mcp_servers scope directly instead of isClaudeAISubscriber().
|
||||
// In non-interactive mode, isClaudeAISubscriber() returns false when ANTHROPIC_API_KEY
|
||||
// is set (even with valid OAuth tokens) because preferThirdPartyAuthentication() causes
|
||||
// isAnthropicAuthEnabled() to return false. Checking the scope directly allows users
|
||||
// with both API keys and OAuth tokens to access claude.ai MCPs in print mode.
|
||||
if (!tokens.scopes?.includes('user:mcp_servers')) {
|
||||
logForDebugging(
|
||||
`[claudeai-mcp] Missing user:mcp_servers scope (scopes=${tokens.scopes?.join(',') || 'none'})`,
|
||||
)
|
||||
logEvent('tengu_claudeai_mcp_eligibility', {
|
||||
state:
|
||||
'missing_scope' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {}
|
||||
}
|
||||
|
||||
const baseUrl = getOauthConfig().BASE_API_URL
|
||||
const url = `${baseUrl}/v1/mcp_servers?limit=1000`
|
||||
|
||||
logForDebugging(`[claudeai-mcp] Fetching from ${url}`)
|
||||
|
||||
const response = await axios.get<ClaudeAIMcpServersResponse>(url, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${tokens.accessToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
'anthropic-beta': MCP_SERVERS_BETA_HEADER,
|
||||
'anthropic-version': '2023-06-01',
|
||||
},
|
||||
timeout: FETCH_TIMEOUT_MS,
|
||||
})
|
||||
|
||||
const configs: Record<string, ScopedMcpServerConfig> = {}
|
||||
// Track used normalized names to detect collisions and assign (2), (3), etc. suffixes.
|
||||
// We check the final normalized name (including suffix) to handle edge cases where
|
||||
// a suffixed name collides with another server's base name (e.g., "Example Server 2"
|
||||
// colliding with "Example Server! (2)" which both normalize to claude_ai_Example_Server_2).
|
||||
const usedNormalizedNames = new Set<string>()
|
||||
|
||||
for (const server of response.data.data) {
|
||||
const baseName = `claude.ai ${server.display_name}`
|
||||
|
||||
// Try without suffix first, then increment until we find an unused normalized name
|
||||
let finalName = baseName
|
||||
let finalNormalized = normalizeNameForMCP(finalName)
|
||||
let count = 1
|
||||
while (usedNormalizedNames.has(finalNormalized)) {
|
||||
count++
|
||||
finalName = `${baseName} (${count})`
|
||||
finalNormalized = normalizeNameForMCP(finalName)
|
||||
}
|
||||
usedNormalizedNames.add(finalNormalized)
|
||||
|
||||
configs[finalName] = {
|
||||
type: 'claudeai-proxy',
|
||||
url: server.url,
|
||||
id: server.id,
|
||||
scope: 'claudeai',
|
||||
}
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[claudeai-mcp] Fetched ${Object.keys(configs).length} servers`,
|
||||
)
|
||||
logEvent('tengu_claudeai_mcp_eligibility', {
|
||||
state:
|
||||
'eligible' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return configs
|
||||
} catch {
|
||||
logForDebugging(`[claudeai-mcp] Fetch failed`)
|
||||
return {}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
/**
|
||||
* Clears the memoized cache for fetchClaudeAIMcpConfigsIfEligible.
|
||||
* Call this after login so the next fetch will use the new auth tokens.
|
||||
*/
|
||||
export function clearClaudeAIMcpConfigsCache(): void {
|
||||
fetchClaudeAIMcpConfigsIfEligible.cache.clear?.()
|
||||
// Also clear the auth cache so freshly-authorized servers get re-connected
|
||||
clearMcpAuthCache()
|
||||
}
|
||||
|
||||
/**
|
||||
* Record that a claude.ai connector successfully connected. Idempotent.
|
||||
*
|
||||
* Gates the "N connectors unavailable/need auth" startup notifications: a
|
||||
* connector that was working yesterday and is now failed is a state change
|
||||
* worth surfacing; an org-configured connector that's been needs-auth since
|
||||
* it showed up is one the user has demonstrably ignored.
|
||||
*/
|
||||
export function markClaudeAiMcpConnected(name: string): void {
|
||||
saveGlobalConfig(current => {
|
||||
const seen = current.claudeAiMcpEverConnected ?? []
|
||||
if (seen.includes(name)) return current
|
||||
return { ...current, claudeAiMcpEverConnected: [...seen, name] }
|
||||
})
|
||||
}
|
||||
|
||||
export function hasClaudeAiMcpEverConnected(name: string): boolean {
|
||||
return (getGlobalConfig().claudeAiMcpEverConnected ?? []).includes(name)
|
||||
}
|
||||
3348
src/services/mcp/client.ts
Normal file
3348
src/services/mcp/client.ts
Normal file
File diff suppressed because it is too large
Load Diff
1578
src/services/mcp/config.ts
Normal file
1578
src/services/mcp/config.ts
Normal file
File diff suppressed because it is too large
Load Diff
313
src/services/mcp/elicitationHandler.ts
Normal file
313
src/services/mcp/elicitationHandler.ts
Normal file
@@ -0,0 +1,313 @@
|
||||
import type { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
import {
|
||||
ElicitationCompleteNotificationSchema,
|
||||
type ElicitRequestParams,
|
||||
ElicitRequestSchema,
|
||||
type ElicitResult,
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
import type { AppState } from '../../state/AppState.js'
|
||||
import {
|
||||
executeElicitationHooks,
|
||||
executeElicitationResultHooks,
|
||||
executeNotificationHooks,
|
||||
} from '../../utils/hooks.js'
|
||||
import { logMCPDebug, logMCPError } from '../../utils/log.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
|
||||
/** Configuration for the waiting state shown after the user opens a URL. */
|
||||
export type ElicitationWaitingState = {
|
||||
/** Button label, e.g. "Retry now" or "Skip confirmation" */
|
||||
actionLabel: string
|
||||
/** Whether to show a visible Cancel button (e.g. for error-based retry flow) */
|
||||
showCancel?: boolean
|
||||
}
|
||||
|
||||
export type ElicitationRequestEvent = {
|
||||
serverName: string
|
||||
/** The JSON-RPC request ID, unique per server connection. */
|
||||
requestId: string | number
|
||||
params: ElicitRequestParams
|
||||
signal: AbortSignal
|
||||
/**
|
||||
* Resolves the elicitation. For explicit elicitations, all actions are
|
||||
* meaningful. For error-based retry (-32042), 'accept' is a no-op —
|
||||
* the retry is driven by onWaitingDismiss instead.
|
||||
*/
|
||||
respond: (response: ElicitResult) => void
|
||||
/** For URL elicitations: shown after user opens the browser. */
|
||||
waitingState?: ElicitationWaitingState
|
||||
/** Called when phase 2 (waiting) is dismissed by user action or completion. */
|
||||
onWaitingDismiss?: (action: 'dismiss' | 'retry' | 'cancel') => void
|
||||
/** Set to true by the completion notification handler when the server confirms completion. */
|
||||
completed?: boolean
|
||||
}
|
||||
|
||||
function getElicitationMode(params: ElicitRequestParams): 'form' | 'url' {
|
||||
return params.mode === 'url' ? 'url' : 'form'
|
||||
}
|
||||
|
||||
/** Find a queued elicitation event by server name and elicitationId. */
|
||||
function findElicitationInQueue(
|
||||
queue: ElicitationRequestEvent[],
|
||||
serverName: string,
|
||||
elicitationId: string,
|
||||
): number {
|
||||
return queue.findIndex(
|
||||
e =>
|
||||
e.serverName === serverName &&
|
||||
e.params.mode === 'url' &&
|
||||
'elicitationId' in e.params &&
|
||||
e.params.elicitationId === elicitationId,
|
||||
)
|
||||
}
|
||||
|
||||
export function registerElicitationHandler(
|
||||
client: Client,
|
||||
serverName: string,
|
||||
setAppState: (f: (prevState: AppState) => AppState) => void,
|
||||
): void {
|
||||
// Register the elicitation request handler.
|
||||
// Wrapped in try/catch because setRequestHandler throws if the client wasn't
|
||||
// created with elicitation capability declared.
|
||||
try {
|
||||
client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`Received elicitation request: ${jsonStringify(request)}`,
|
||||
)
|
||||
|
||||
const mode = getElicitationMode(request.params)
|
||||
|
||||
logEvent('tengu_mcp_elicitation_shown', {
|
||||
mode: mode as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
|
||||
try {
|
||||
// Run elicitation hooks first - they can provide a response programmatically
|
||||
const hookResponse = await runElicitationHooks(
|
||||
serverName,
|
||||
request.params,
|
||||
extra.signal,
|
||||
)
|
||||
if (hookResponse) {
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`Elicitation resolved by hook: ${jsonStringify(hookResponse)}`,
|
||||
)
|
||||
logEvent('tengu_mcp_elicitation_response', {
|
||||
mode: mode as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
action:
|
||||
hookResponse.action as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return hookResponse
|
||||
}
|
||||
|
||||
const elicitationId =
|
||||
mode === 'url' && 'elicitationId' in request.params
|
||||
? (request.params.elicitationId as string | undefined)
|
||||
: undefined
|
||||
|
||||
const response = new Promise<ElicitResult>(resolve => {
|
||||
const onAbort = () => {
|
||||
resolve({ action: 'cancel' })
|
||||
}
|
||||
|
||||
if (extra.signal.aborted) {
|
||||
onAbort()
|
||||
return
|
||||
}
|
||||
|
||||
const waitingState: ElicitationWaitingState | undefined =
|
||||
elicitationId ? { actionLabel: 'Skip confirmation' } : undefined
|
||||
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
elicitation: {
|
||||
queue: [
|
||||
...prev.elicitation.queue,
|
||||
{
|
||||
serverName,
|
||||
requestId: extra.requestId,
|
||||
params: request.params,
|
||||
signal: extra.signal,
|
||||
waitingState,
|
||||
respond: (result: ElicitResult) => {
|
||||
extra.signal.removeEventListener('abort', onAbort)
|
||||
logEvent('tengu_mcp_elicitation_response', {
|
||||
mode: mode as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
action:
|
||||
result.action as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
resolve(result)
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
}))
|
||||
|
||||
extra.signal.addEventListener('abort', onAbort, { once: true })
|
||||
})
|
||||
const rawResult = await response
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`Elicitation response: ${jsonStringify(rawResult)}`,
|
||||
)
|
||||
const result = await runElicitationResultHooks(
|
||||
serverName,
|
||||
rawResult,
|
||||
extra.signal,
|
||||
mode,
|
||||
elicitationId,
|
||||
)
|
||||
return result
|
||||
} catch (error) {
|
||||
logMCPError(serverName, `Elicitation error: ${error}`)
|
||||
return { action: 'cancel' as const }
|
||||
}
|
||||
})
|
||||
|
||||
// Register handler for elicitation completion notifications (URL mode).
|
||||
// Sets `completed: true` on the matching queue event; the dialog reacts to this flag.
|
||||
client.setNotificationHandler(
|
||||
ElicitationCompleteNotificationSchema,
|
||||
notification => {
|
||||
const { elicitationId } = notification.params
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`Received elicitation completion notification: ${elicitationId}`,
|
||||
)
|
||||
void executeNotificationHooks({
|
||||
message: `MCP server "${serverName}" confirmed elicitation ${elicitationId} complete`,
|
||||
notificationType: 'elicitation_complete',
|
||||
})
|
||||
let found = false
|
||||
setAppState(prev => {
|
||||
const idx = findElicitationInQueue(
|
||||
prev.elicitation.queue,
|
||||
serverName,
|
||||
elicitationId,
|
||||
)
|
||||
if (idx === -1) return prev
|
||||
found = true
|
||||
const queue = [...prev.elicitation.queue]
|
||||
queue[idx] = { ...queue[idx]!, completed: true }
|
||||
return { ...prev, elicitation: { queue } }
|
||||
})
|
||||
if (!found) {
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`Ignoring completion notification for unknown elicitation: ${elicitationId}`,
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
} catch {
|
||||
// Client wasn't created with elicitation capability - nothing to register
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
export async function runElicitationHooks(
|
||||
serverName: string,
|
||||
params: ElicitRequestParams,
|
||||
signal: AbortSignal,
|
||||
): Promise<ElicitResult | undefined> {
|
||||
try {
|
||||
const mode = params.mode === 'url' ? 'url' : 'form'
|
||||
const url = 'url' in params ? (params.url as string) : undefined
|
||||
const elicitationId =
|
||||
'elicitationId' in params
|
||||
? (params.elicitationId as string | undefined)
|
||||
: undefined
|
||||
|
||||
const { elicitationResponse, blockingError } =
|
||||
await executeElicitationHooks({
|
||||
serverName,
|
||||
message: params.message,
|
||||
requestedSchema:
|
||||
'requestedSchema' in params
|
||||
? (params.requestedSchema as Record<string, unknown>)
|
||||
: undefined,
|
||||
signal,
|
||||
mode,
|
||||
url,
|
||||
elicitationId,
|
||||
})
|
||||
|
||||
if (blockingError) {
|
||||
return { action: 'decline' }
|
||||
}
|
||||
|
||||
if (elicitationResponse) {
|
||||
return {
|
||||
action: elicitationResponse.action,
|
||||
content: elicitationResponse.content,
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
} catch (error) {
|
||||
logMCPError(serverName, `Elicitation hook error: ${error}`)
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run ElicitationResult hooks after the user has responded, then fire a
|
||||
* `elicitation_response` notification. Returns a (potentially modified)
|
||||
* ElicitResult — hooks may override the action/content or block the response.
|
||||
*/
|
||||
export async function runElicitationResultHooks(
|
||||
serverName: string,
|
||||
result: ElicitResult,
|
||||
signal: AbortSignal,
|
||||
mode?: 'form' | 'url',
|
||||
elicitationId?: string,
|
||||
): Promise<ElicitResult> {
|
||||
try {
|
||||
const { elicitationResultResponse, blockingError } =
|
||||
await executeElicitationResultHooks({
|
||||
serverName,
|
||||
action: result.action,
|
||||
content: result.content as Record<string, unknown> | undefined,
|
||||
signal,
|
||||
mode,
|
||||
elicitationId,
|
||||
})
|
||||
|
||||
if (blockingError) {
|
||||
void executeNotificationHooks({
|
||||
message: `Elicitation response for server "${serverName}": decline`,
|
||||
notificationType: 'elicitation_response',
|
||||
})
|
||||
return { action: 'decline' }
|
||||
}
|
||||
|
||||
const finalResult = elicitationResultResponse
|
||||
? {
|
||||
action: elicitationResultResponse.action,
|
||||
content: elicitationResultResponse.content ?? result.content,
|
||||
}
|
||||
: result
|
||||
|
||||
// Fire a notification for observability
|
||||
void executeNotificationHooks({
|
||||
message: `Elicitation response for server "${serverName}": ${finalResult.action}`,
|
||||
notificationType: 'elicitation_response',
|
||||
})
|
||||
|
||||
return finalResult
|
||||
} catch (error) {
|
||||
logMCPError(serverName, `ElicitationResult hook error: ${error}`)
|
||||
// Fire notification even on error
|
||||
void executeNotificationHooks({
|
||||
message: `Elicitation response for server "${serverName}": ${result.action}`,
|
||||
notificationType: 'elicitation_response',
|
||||
})
|
||||
return result
|
||||
}
|
||||
}
|
||||
38
src/services/mcp/envExpansion.ts
Normal file
38
src/services/mcp/envExpansion.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Shared utilities for expanding environment variables in MCP server configurations
|
||||
*/
|
||||
|
||||
/**
|
||||
* Expand environment variables in a string value
|
||||
* Handles ${VAR} and ${VAR:-default} syntax
|
||||
* @returns Object with expanded string and list of missing variables
|
||||
*/
|
||||
export function expandEnvVarsInString(value: string): {
|
||||
expanded: string
|
||||
missingVars: string[]
|
||||
} {
|
||||
const missingVars: string[] = []
|
||||
|
||||
const expanded = value.replace(/\$\{([^}]+)\}/g, (match, varContent) => {
|
||||
// Split on :- to support default values (limit to 2 parts to preserve :- in defaults)
|
||||
const [varName, defaultValue] = varContent.split(':-', 2)
|
||||
const envValue = process.env[varName]
|
||||
|
||||
if (envValue !== undefined) {
|
||||
return envValue
|
||||
}
|
||||
if (defaultValue !== undefined) {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Track missing variable for error reporting
|
||||
missingVars.push(varName)
|
||||
// Return original if not found (allows debugging but will be reported as error)
|
||||
return match
|
||||
})
|
||||
|
||||
return {
|
||||
expanded,
|
||||
missingVars,
|
||||
}
|
||||
}
|
||||
138
src/services/mcp/headersHelper.ts
Normal file
138
src/services/mcp/headersHelper.ts
Normal file
@@ -0,0 +1,138 @@
|
||||
import { getIsNonInteractiveSession } from '../../bootstrap/state.js'
|
||||
import { checkHasTrustDialogAccepted } from '../../utils/config.js'
|
||||
import { logAntError } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { execFileNoThrowWithCwd } from '../../utils/execFileNoThrow.js'
|
||||
import { logError, logMCPDebug, logMCPError } from '../../utils/log.js'
|
||||
import { jsonParse } from '../../utils/slowOperations.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
import type {
|
||||
McpHTTPServerConfig,
|
||||
McpSSEServerConfig,
|
||||
McpWebSocketServerConfig,
|
||||
ScopedMcpServerConfig,
|
||||
} from './types.js'
|
||||
|
||||
/**
|
||||
* Check if the MCP server config comes from project settings (projectSettings or localSettings)
|
||||
* This is important for security checks
|
||||
*/
|
||||
function isMcpServerFromProjectOrLocalSettings(
|
||||
config: ScopedMcpServerConfig,
|
||||
): boolean {
|
||||
return config.scope === 'project' || config.scope === 'local'
|
||||
}
|
||||
|
||||
/**
|
||||
* Get dynamic headers for an MCP server using the headersHelper script
|
||||
* @param serverName The name of the MCP server
|
||||
* @param config The MCP server configuration
|
||||
* @returns Headers object or null if not configured or failed
|
||||
*/
|
||||
export async function getMcpHeadersFromHelper(
|
||||
serverName: string,
|
||||
config: McpSSEServerConfig | McpHTTPServerConfig | McpWebSocketServerConfig,
|
||||
): Promise<Record<string, string> | null> {
|
||||
if (!config.headersHelper) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Security check for project/local settings
|
||||
// Skip trust check in non-interactive mode (e.g., CI/CD, automation)
|
||||
if (
|
||||
'scope' in config &&
|
||||
isMcpServerFromProjectOrLocalSettings(config as ScopedMcpServerConfig) &&
|
||||
!getIsNonInteractiveSession()
|
||||
) {
|
||||
// Check if trust has been established for this project
|
||||
const hasTrust = checkHasTrustDialogAccepted()
|
||||
if (!hasTrust) {
|
||||
const error = new Error(
|
||||
`Security: headersHelper for MCP server '${serverName}' executed before workspace trust is confirmed. If you see this message, post in ${MACRO.FEEDBACK_CHANNEL}.`,
|
||||
)
|
||||
logAntError('MCP headersHelper invoked before trust check', error)
|
||||
logEvent('tengu_mcp_headersHelper_missing_trust', {})
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
logMCPDebug(serverName, 'Executing headersHelper to get dynamic headers')
|
||||
const execResult = await execFileNoThrowWithCwd(config.headersHelper, [], {
|
||||
shell: true,
|
||||
timeout: 10000,
|
||||
// Pass server context so one helper script can serve multiple MCP servers
|
||||
// (git credential-helper style). See deshaw/anthropic-issues#28.
|
||||
env: {
|
||||
...process.env,
|
||||
CLAUDE_CODE_MCP_SERVER_NAME: serverName,
|
||||
CLAUDE_CODE_MCP_SERVER_URL: config.url,
|
||||
},
|
||||
})
|
||||
if (execResult.code !== 0 || !execResult.stdout) {
|
||||
throw new Error(
|
||||
`headersHelper for MCP server '${serverName}' did not return a valid value`,
|
||||
)
|
||||
}
|
||||
const result = execResult.stdout.trim()
|
||||
|
||||
const headers = jsonParse(result)
|
||||
if (
|
||||
typeof headers !== 'object' ||
|
||||
headers === null ||
|
||||
Array.isArray(headers)
|
||||
) {
|
||||
throw new Error(
|
||||
`headersHelper for MCP server '${serverName}' must return a JSON object with string key-value pairs`,
|
||||
)
|
||||
}
|
||||
|
||||
// Validate all values are strings
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
if (typeof value !== 'string') {
|
||||
throw new Error(
|
||||
`headersHelper for MCP server '${serverName}' returned non-string value for key "${key}": ${typeof value}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`Successfully retrieved ${Object.keys(headers).length} headers from headersHelper`,
|
||||
)
|
||||
return headers as Record<string, string>
|
||||
} catch (error) {
|
||||
logMCPError(
|
||||
serverName,
|
||||
`Error getting headers from headersHelper: ${errorMessage(error)}`,
|
||||
)
|
||||
logError(
|
||||
new Error(
|
||||
`Error getting MCP headers from headersHelper for server '${serverName}': ${errorMessage(error)}`,
|
||||
),
|
||||
)
|
||||
// Return null instead of throwing to avoid blocking the connection
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get combined headers for an MCP server (static + dynamic)
|
||||
* @param serverName The name of the MCP server
|
||||
* @param config The MCP server configuration
|
||||
* @returns Combined headers object
|
||||
*/
|
||||
export async function getMcpServerHeaders(
|
||||
serverName: string,
|
||||
config: McpSSEServerConfig | McpHTTPServerConfig | McpWebSocketServerConfig,
|
||||
): Promise<Record<string, string>> {
|
||||
const staticHeaders = config.headers || {}
|
||||
const dynamicHeaders =
|
||||
(await getMcpHeadersFromHelper(serverName, config)) || {}
|
||||
|
||||
// Dynamic headers override static headers if both are present
|
||||
return {
|
||||
...staticHeaders,
|
||||
...dynamicHeaders,
|
||||
}
|
||||
}
|
||||
106
src/services/mcp/mcpStringUtils.ts
Normal file
106
src/services/mcp/mcpStringUtils.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
/**
|
||||
* Pure string utility functions for MCP tool/server name parsing.
|
||||
* This file has no heavy dependencies to keep it lightweight for
|
||||
* consumers that only need string parsing (e.g., permissionValidation).
|
||||
*/
|
||||
|
||||
import { normalizeNameForMCP } from './normalization.js'
|
||||
|
||||
/*
|
||||
* Extracts MCP server information from a tool name string
|
||||
* @param toolString The string to parse. Expected format: "mcp__serverName__toolName"
|
||||
* @returns An object containing server name and optional tool name, or null if not a valid MCP rule
|
||||
*
|
||||
* Known limitation: If a server name contains "__", parsing will be incorrect.
|
||||
* For example, "mcp__my__server__tool" would parse as server="my" and tool="server__tool"
|
||||
* instead of server="my__server" and tool="tool". This is rare in practice since server
|
||||
* names typically don't contain double underscores.
|
||||
*/
|
||||
export function mcpInfoFromString(toolString: string): {
|
||||
serverName: string
|
||||
toolName: string | undefined
|
||||
} | null {
|
||||
const parts = toolString.split('__')
|
||||
const [mcpPart, serverName, ...toolNameParts] = parts
|
||||
if (mcpPart !== 'mcp' || !serverName) {
|
||||
return null
|
||||
}
|
||||
// Join all parts after server name to preserve double underscores in tool names
|
||||
const toolName =
|
||||
toolNameParts.length > 0 ? toolNameParts.join('__') : undefined
|
||||
return { serverName, toolName }
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates the MCP tool/command name prefix for a given server
|
||||
* @param serverName Name of the MCP server
|
||||
* @returns The prefix string
|
||||
*/
|
||||
export function getMcpPrefix(serverName: string): string {
|
||||
return `mcp__${normalizeNameForMCP(serverName)}__`
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a fully qualified MCP tool name from server and tool names.
|
||||
* Inverse of mcpInfoFromString().
|
||||
* @param serverName Name of the MCP server (unnormalized)
|
||||
* @param toolName Name of the tool (unnormalized)
|
||||
* @returns The fully qualified name, e.g., "mcp__server__tool"
|
||||
*/
|
||||
export function buildMcpToolName(serverName: string, toolName: string): string {
|
||||
return `${getMcpPrefix(serverName)}${normalizeNameForMCP(toolName)}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the name to use for permission rule matching.
|
||||
* For MCP tools, uses the fully qualified mcp__server__tool name so that
|
||||
* deny rules targeting builtins (e.g., "Write") don't match unprefixed MCP
|
||||
* replacements that share the same display name. Falls back to `tool.name`.
|
||||
*/
|
||||
export function getToolNameForPermissionCheck(tool: {
|
||||
name: string
|
||||
mcpInfo?: { serverName: string; toolName: string }
|
||||
}): string {
|
||||
return tool.mcpInfo
|
||||
? buildMcpToolName(tool.mcpInfo.serverName, tool.mcpInfo.toolName)
|
||||
: tool.name
|
||||
}
|
||||
|
||||
/*
|
||||
* Extracts the display name from an MCP tool/command name
|
||||
* @param fullName The full MCP tool/command name (e.g., "mcp__server_name__tool_name")
|
||||
* @param serverName The server name to remove from the prefix
|
||||
* @returns The display name without the MCP prefix
|
||||
*/
|
||||
export function getMcpDisplayName(
|
||||
fullName: string,
|
||||
serverName: string,
|
||||
): string {
|
||||
const prefix = `mcp__${normalizeNameForMCP(serverName)}__`
|
||||
return fullName.replace(prefix, '')
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts just the tool/command display name from a userFacingName
|
||||
* @param userFacingName The full user-facing name (e.g., "github - Add comment to issue (MCP)")
|
||||
* @returns The display name without server prefix and (MCP) suffix
|
||||
*/
|
||||
export function extractMcpToolDisplayName(userFacingName: string): string {
|
||||
// This is really ugly but our current Tool type doesn't make it easy to have different display names for different purposes.
|
||||
|
||||
// First, remove the (MCP) suffix if present
|
||||
let withoutSuffix = userFacingName.replace(/\s*\(MCP\)\s*$/, '')
|
||||
|
||||
// Trim the result
|
||||
withoutSuffix = withoutSuffix.trim()
|
||||
|
||||
// Then, remove the server prefix (everything before " - ")
|
||||
const dashIndex = withoutSuffix.indexOf(' - ')
|
||||
if (dashIndex !== -1) {
|
||||
const displayName = withoutSuffix.substring(dashIndex + 3).trim()
|
||||
return displayName
|
||||
}
|
||||
|
||||
// If no dash found, return the string without (MCP)
|
||||
return withoutSuffix
|
||||
}
|
||||
23
src/services/mcp/normalization.ts
Normal file
23
src/services/mcp/normalization.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* Pure utility functions for MCP name normalization.
|
||||
* This file has no dependencies to avoid circular imports.
|
||||
*/
|
||||
|
||||
// Claude.ai server names are prefixed with this string
|
||||
const CLAUDEAI_SERVER_PREFIX = 'claude.ai '
|
||||
|
||||
/**
|
||||
* Normalize server names to be compatible with the API pattern ^[a-zA-Z0-9_-]{1,64}$
|
||||
* Replaces any invalid characters (including dots and spaces) with underscores.
|
||||
*
|
||||
* For claude.ai servers (names starting with "claude.ai "), also collapses
|
||||
* consecutive underscores and strips leading/trailing underscores to prevent
|
||||
* interference with the __ delimiter used in MCP tool names.
|
||||
*/
|
||||
export function normalizeNameForMCP(name: string): string {
|
||||
let normalized = name.replace(/[^a-zA-Z0-9_-]/g, '_')
|
||||
if (name.startsWith(CLAUDEAI_SERVER_PREFIX)) {
|
||||
normalized = normalized.replace(/_+/g, '_').replace(/^_|_$/g, '')
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
78
src/services/mcp/oauthPort.ts
Normal file
78
src/services/mcp/oauthPort.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
/**
|
||||
* OAuth redirect port helpers — extracted from auth.ts to break the
|
||||
* auth.ts ↔ xaaIdpLogin.ts circular dependency.
|
||||
*/
|
||||
import { createServer } from 'http'
|
||||
import { getPlatform } from '../../utils/platform.js'
|
||||
|
||||
// Windows dynamic port range 49152-65535 is reserved
|
||||
const REDIRECT_PORT_RANGE =
|
||||
getPlatform() === 'windows'
|
||||
? { min: 39152, max: 49151 }
|
||||
: { min: 49152, max: 65535 }
|
||||
const REDIRECT_PORT_FALLBACK = 3118
|
||||
|
||||
/**
|
||||
* Builds a redirect URI on localhost with the given port and a fixed `/callback` path.
|
||||
*
|
||||
* RFC 8252 Section 7.3 (OAuth for Native Apps): loopback redirect URIs match any
|
||||
* port as long as the path matches.
|
||||
*/
|
||||
export function buildRedirectUri(
|
||||
port: number = REDIRECT_PORT_FALLBACK,
|
||||
): string {
|
||||
return `http://localhost:${port}/callback`
|
||||
}
|
||||
|
||||
function getMcpOAuthCallbackPort(): number | undefined {
|
||||
const port = parseInt(process.env.MCP_OAUTH_CALLBACK_PORT || '', 10)
|
||||
return port > 0 ? port : undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds an available port in the specified range for OAuth redirect
|
||||
* Uses random selection for better security
|
||||
*/
|
||||
export async function findAvailablePort(): Promise<number> {
|
||||
// First, try the configured port if specified
|
||||
const configuredPort = getMcpOAuthCallbackPort()
|
||||
if (configuredPort) {
|
||||
return configuredPort
|
||||
}
|
||||
|
||||
const { min, max } = REDIRECT_PORT_RANGE
|
||||
const range = max - min + 1
|
||||
const maxAttempts = Math.min(range, 100) // Don't try forever
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
const port = min + Math.floor(Math.random() * range)
|
||||
|
||||
try {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const testServer = createServer()
|
||||
testServer.once('error', reject)
|
||||
testServer.listen(port, () => {
|
||||
testServer.close(() => resolve())
|
||||
})
|
||||
})
|
||||
return port
|
||||
} catch {
|
||||
// Port in use, try another random port
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If random selection failed, try the fallback port
|
||||
try {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const testServer = createServer()
|
||||
testServer.once('error', reject)
|
||||
testServer.listen(REDIRECT_PORT_FALLBACK, () => {
|
||||
testServer.close(() => resolve())
|
||||
})
|
||||
})
|
||||
return REDIRECT_PORT_FALLBACK
|
||||
} catch {
|
||||
throw new Error(`No available ports for OAuth redirect`)
|
||||
}
|
||||
}
|
||||
72
src/services/mcp/officialRegistry.ts
Normal file
72
src/services/mcp/officialRegistry.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import axios from 'axios'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
|
||||
type RegistryServer = {
|
||||
server: {
|
||||
remotes?: Array<{ url: string }>
|
||||
}
|
||||
}
|
||||
|
||||
type RegistryResponse = {
|
||||
servers: RegistryServer[]
|
||||
}
|
||||
|
||||
// URLs stripped of query string and trailing slash — matches the normalization
|
||||
// done by getLoggingSafeMcpBaseUrl so direct Set.has() lookup works.
|
||||
let officialUrls: Set<string> | undefined = undefined
|
||||
|
||||
function normalizeUrl(url: string): string | undefined {
|
||||
try {
|
||||
const u = new URL(url)
|
||||
u.search = ''
|
||||
return u.toString().replace(/\/$/, '')
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire-and-forget fetch of the official MCP registry.
|
||||
* Populates officialUrls for isOfficialMcpUrl lookups.
|
||||
*/
|
||||
export async function prefetchOfficialMcpUrls(): Promise<void> {
|
||||
if (process.env.CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.get<RegistryResponse>(
|
||||
'https://api.anthropic.com/mcp-registry/v0/servers?version=latest&visibility=commercial',
|
||||
{ timeout: 5000 },
|
||||
)
|
||||
|
||||
const urls = new Set<string>()
|
||||
for (const entry of response.data.servers) {
|
||||
for (const remote of entry.server.remotes ?? []) {
|
||||
const normalized = normalizeUrl(remote.url)
|
||||
if (normalized) {
|
||||
urls.add(normalized)
|
||||
}
|
||||
}
|
||||
}
|
||||
officialUrls = urls
|
||||
logForDebugging(`[mcp-registry] Loaded ${urls.size} official MCP URLs`)
|
||||
} catch (error) {
|
||||
logForDebugging(`Failed to fetch MCP registry: ${errorMessage(error)}`, {
|
||||
level: 'error',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true iff the given (already-normalized via getLoggingSafeMcpBaseUrl)
|
||||
* URL is in the official MCP registry. Undefined registry → false (fail-closed).
|
||||
*/
|
||||
export function isOfficialMcpUrl(normalizedUrl: string): boolean {
|
||||
return officialUrls?.has(normalizedUrl) ?? false
|
||||
}
|
||||
|
||||
export function resetOfficialMcpUrlsForTesting(): void {
|
||||
officialUrls = undefined
|
||||
}
|
||||
258
src/services/mcp/types.ts
Normal file
258
src/services/mcp/types.ts
Normal file
@@ -0,0 +1,258 @@
|
||||
import type { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
import type {
|
||||
Resource,
|
||||
ServerCapabilities,
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
import { z } from 'zod/v4'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
|
||||
// Configuration schemas and types
|
||||
export const ConfigScopeSchema = lazySchema(() =>
|
||||
z.enum([
|
||||
'local',
|
||||
'user',
|
||||
'project',
|
||||
'dynamic',
|
||||
'enterprise',
|
||||
'claudeai',
|
||||
'managed',
|
||||
]),
|
||||
)
|
||||
export type ConfigScope = z.infer<ReturnType<typeof ConfigScopeSchema>>
|
||||
|
||||
export const TransportSchema = lazySchema(() =>
|
||||
z.enum(['stdio', 'sse', 'sse-ide', 'http', 'ws', 'sdk']),
|
||||
)
|
||||
export type Transport = z.infer<ReturnType<typeof TransportSchema>>
|
||||
|
||||
export const McpStdioServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('stdio').optional(), // Optional for backwards compatibility
|
||||
command: z.string().min(1, 'Command cannot be empty'),
|
||||
args: z.array(z.string()).default([]),
|
||||
env: z.record(z.string(), z.string()).optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
// Cross-App Access (XAA / SEP-990): just a per-server flag. IdP connection
|
||||
// details (issuer, clientId, callbackPort) come from settings.xaaIdp — configured
|
||||
// once, shared across all XAA-enabled servers. clientId/clientSecret (parent
|
||||
// oauth config + keychain slot) are for the MCP server's AS.
|
||||
const McpXaaConfigSchema = lazySchema(() => z.boolean())
|
||||
|
||||
const McpOAuthConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
clientId: z.string().optional(),
|
||||
callbackPort: z.number().int().positive().optional(),
|
||||
authServerMetadataUrl: z
|
||||
.string()
|
||||
.url()
|
||||
.startsWith('https://', {
|
||||
message: 'authServerMetadataUrl must use https://',
|
||||
})
|
||||
.optional(),
|
||||
xaa: McpXaaConfigSchema().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const McpSSEServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('sse'),
|
||||
url: z.string(),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
headersHelper: z.string().optional(),
|
||||
oauth: McpOAuthConfigSchema().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
// Internal-only server type for IDE extensions
|
||||
export const McpSSEIDEServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('sse-ide'),
|
||||
url: z.string(),
|
||||
ideName: z.string(),
|
||||
ideRunningInWindows: z.boolean().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
// Internal-only server type for IDE extensions
|
||||
export const McpWebSocketIDEServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('ws-ide'),
|
||||
url: z.string(),
|
||||
ideName: z.string(),
|
||||
authToken: z.string().optional(),
|
||||
ideRunningInWindows: z.boolean().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const McpHTTPServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('http'),
|
||||
url: z.string(),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
headersHelper: z.string().optional(),
|
||||
oauth: McpOAuthConfigSchema().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const McpWebSocketServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('ws'),
|
||||
url: z.string(),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
headersHelper: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const McpSdkServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('sdk'),
|
||||
name: z.string(),
|
||||
}),
|
||||
)
|
||||
|
||||
// Config type for Claude.ai proxy servers
|
||||
export const McpClaudeAIProxyServerConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
type: z.literal('claudeai-proxy'),
|
||||
url: z.string(),
|
||||
id: z.string(),
|
||||
}),
|
||||
)
|
||||
|
||||
export const McpServerConfigSchema = lazySchema(() =>
|
||||
z.union([
|
||||
McpStdioServerConfigSchema(),
|
||||
McpSSEServerConfigSchema(),
|
||||
McpSSEIDEServerConfigSchema(),
|
||||
McpWebSocketIDEServerConfigSchema(),
|
||||
McpHTTPServerConfigSchema(),
|
||||
McpWebSocketServerConfigSchema(),
|
||||
McpSdkServerConfigSchema(),
|
||||
McpClaudeAIProxyServerConfigSchema(),
|
||||
]),
|
||||
)
|
||||
|
||||
export type McpStdioServerConfig = z.infer<
|
||||
ReturnType<typeof McpStdioServerConfigSchema>
|
||||
>
|
||||
export type McpSSEServerConfig = z.infer<
|
||||
ReturnType<typeof McpSSEServerConfigSchema>
|
||||
>
|
||||
export type McpSSEIDEServerConfig = z.infer<
|
||||
ReturnType<typeof McpSSEIDEServerConfigSchema>
|
||||
>
|
||||
export type McpWebSocketIDEServerConfig = z.infer<
|
||||
ReturnType<typeof McpWebSocketIDEServerConfigSchema>
|
||||
>
|
||||
export type McpHTTPServerConfig = z.infer<
|
||||
ReturnType<typeof McpHTTPServerConfigSchema>
|
||||
>
|
||||
export type McpWebSocketServerConfig = z.infer<
|
||||
ReturnType<typeof McpWebSocketServerConfigSchema>
|
||||
>
|
||||
export type McpSdkServerConfig = z.infer<
|
||||
ReturnType<typeof McpSdkServerConfigSchema>
|
||||
>
|
||||
export type McpClaudeAIProxyServerConfig = z.infer<
|
||||
ReturnType<typeof McpClaudeAIProxyServerConfigSchema>
|
||||
>
|
||||
export type McpServerConfig = z.infer<ReturnType<typeof McpServerConfigSchema>>
|
||||
|
||||
export type ScopedMcpServerConfig = McpServerConfig & {
|
||||
scope: ConfigScope
|
||||
// For plugin-provided servers: the providing plugin's LoadedPlugin.source
|
||||
// (e.g. 'slack@anthropic'). Stashed at config-build time so the channel
|
||||
// gate doesn't have to race AppState.plugins.enabled hydration.
|
||||
pluginSource?: string
|
||||
}
|
||||
|
||||
export const McpJsonConfigSchema = lazySchema(() =>
|
||||
z.object({
|
||||
mcpServers: z.record(z.string(), McpServerConfigSchema()),
|
||||
}),
|
||||
)
|
||||
|
||||
export type McpJsonConfig = z.infer<ReturnType<typeof McpJsonConfigSchema>>
|
||||
|
||||
// Server connection types
|
||||
export type ConnectedMCPServer = {
|
||||
client: Client
|
||||
name: string
|
||||
type: 'connected'
|
||||
capabilities: ServerCapabilities
|
||||
serverInfo?: {
|
||||
name: string
|
||||
version: string
|
||||
}
|
||||
instructions?: string
|
||||
config: ScopedMcpServerConfig
|
||||
cleanup: () => Promise<void>
|
||||
}
|
||||
|
||||
export type FailedMCPServer = {
|
||||
name: string
|
||||
type: 'failed'
|
||||
config: ScopedMcpServerConfig
|
||||
error?: string
|
||||
}
|
||||
|
||||
export type NeedsAuthMCPServer = {
|
||||
name: string
|
||||
type: 'needs-auth'
|
||||
config: ScopedMcpServerConfig
|
||||
}
|
||||
|
||||
export type PendingMCPServer = {
|
||||
name: string
|
||||
type: 'pending'
|
||||
config: ScopedMcpServerConfig
|
||||
reconnectAttempt?: number
|
||||
maxReconnectAttempts?: number
|
||||
}
|
||||
|
||||
export type DisabledMCPServer = {
|
||||
name: string
|
||||
type: 'disabled'
|
||||
config: ScopedMcpServerConfig
|
||||
}
|
||||
|
||||
export type MCPServerConnection =
|
||||
| ConnectedMCPServer
|
||||
| FailedMCPServer
|
||||
| NeedsAuthMCPServer
|
||||
| PendingMCPServer
|
||||
| DisabledMCPServer
|
||||
|
||||
// Resource types
|
||||
export type ServerResource = Resource & { server: string }
|
||||
|
||||
// MCP CLI State types
|
||||
export interface SerializedTool {
|
||||
name: string
|
||||
description: string
|
||||
inputJSONSchema?: {
|
||||
[x: string]: unknown
|
||||
type: 'object'
|
||||
properties?: {
|
||||
[x: string]: unknown
|
||||
}
|
||||
}
|
||||
isMcp?: boolean
|
||||
originalToolName?: string // Original unnormalized tool name from MCP server
|
||||
}
|
||||
|
||||
export interface SerializedClient {
|
||||
name: string
|
||||
type: 'connected' | 'failed' | 'needs-auth' | 'pending' | 'disabled'
|
||||
capabilities?: ServerCapabilities
|
||||
}
|
||||
|
||||
export interface MCPCliState {
|
||||
clients: SerializedClient[]
|
||||
configs: Record<string, ScopedMcpServerConfig>
|
||||
tools: SerializedTool[]
|
||||
resources: Record<string, ServerResource[]>
|
||||
normalizedNames?: Record<string, string> // Maps normalized names to original names
|
||||
}
|
||||
1141
src/services/mcp/useManageMCPConnections.ts
Normal file
1141
src/services/mcp/useManageMCPConnections.ts
Normal file
File diff suppressed because it is too large
Load Diff
575
src/services/mcp/utils.ts
Normal file
575
src/services/mcp/utils.ts
Normal file
@@ -0,0 +1,575 @@
|
||||
import { createHash } from 'crypto'
|
||||
import { join } from 'path'
|
||||
import { getIsNonInteractiveSession } from '../../bootstrap/state.js'
|
||||
import type { Command } from '../../commands.js'
|
||||
import type { AgentMcpServerInfo } from '../../components/mcp/types.js'
|
||||
import type { Tool } from '../../Tool.js'
|
||||
import type { AgentDefinition } from '../../tools/AgentTool/loadAgentsDir.js'
|
||||
import { getCwd } from '../../utils/cwd.js'
|
||||
import { getGlobalClaudeFile } from '../../utils/env.js'
|
||||
import { isSettingSourceEnabled } from '../../utils/settings/constants.js'
|
||||
import {
|
||||
getSettings_DEPRECATED,
|
||||
hasSkipDangerousModePermissionPrompt,
|
||||
} from '../../utils/settings/settings.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
import { getEnterpriseMcpFilePath, getMcpConfigByName } from './config.js'
|
||||
import { mcpInfoFromString } from './mcpStringUtils.js'
|
||||
import { normalizeNameForMCP } from './normalization.js'
|
||||
import {
|
||||
type ConfigScope,
|
||||
ConfigScopeSchema,
|
||||
type MCPServerConnection,
|
||||
type McpHTTPServerConfig,
|
||||
type McpServerConfig,
|
||||
type McpSSEServerConfig,
|
||||
type McpStdioServerConfig,
|
||||
type McpWebSocketServerConfig,
|
||||
type ScopedMcpServerConfig,
|
||||
type ServerResource,
|
||||
} from './types.js'
|
||||
|
||||
/**
|
||||
* Filters tools by MCP server name
|
||||
*
|
||||
* @param tools Array of tools to filter
|
||||
* @param serverName Name of the MCP server
|
||||
* @returns Tools belonging to the specified server
|
||||
*/
|
||||
export function filterToolsByServer(tools: Tool[], serverName: string): Tool[] {
|
||||
const prefix = `mcp__${normalizeNameForMCP(serverName)}__`
|
||||
return tools.filter(tool => tool.name?.startsWith(prefix))
|
||||
}
|
||||
|
||||
/**
|
||||
* True when a command belongs to the given MCP server.
|
||||
*
|
||||
* MCP **prompts** are named `mcp__<server>__<prompt>` (wire-format constraint);
|
||||
* MCP **skills** are named `<server>:<skill>` (matching plugin/nested-dir skill
|
||||
* naming). Both live in `mcp.commands`, so cleanup and filtering must match
|
||||
* either shape.
|
||||
*/
|
||||
export function commandBelongsToServer(
|
||||
command: Command,
|
||||
serverName: string,
|
||||
): boolean {
|
||||
const normalized = normalizeNameForMCP(serverName)
|
||||
const name = command.name
|
||||
if (!name) return false
|
||||
return (
|
||||
name.startsWith(`mcp__${normalized}__`) || name.startsWith(`${normalized}:`)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters commands by MCP server name
|
||||
* @param commands Array of commands to filter
|
||||
* @param serverName Name of the MCP server
|
||||
* @returns Commands belonging to the specified server
|
||||
*/
|
||||
export function filterCommandsByServer(
|
||||
commands: Command[],
|
||||
serverName: string,
|
||||
): Command[] {
|
||||
return commands.filter(c => commandBelongsToServer(c, serverName))
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters MCP **prompts** (not skills) by server. Used by the `/mcp` menu
|
||||
* capabilities display — skills are a separate feature shown in `/skills`,
|
||||
* so they mustn't inflate the "prompts" capability badge.
|
||||
*
|
||||
* The distinguisher is `loadedFrom === 'mcp'`: MCP skills set it, MCP
|
||||
* prompts don't (they use `isMcp: true` instead).
|
||||
*/
|
||||
export function filterMcpPromptsByServer(
|
||||
commands: Command[],
|
||||
serverName: string,
|
||||
): Command[] {
|
||||
return commands.filter(
|
||||
c =>
|
||||
commandBelongsToServer(c, serverName) &&
|
||||
!(c.type === 'prompt' && c.loadedFrom === 'mcp'),
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters resources by MCP server name
|
||||
* @param resources Array of resources to filter
|
||||
* @param serverName Name of the MCP server
|
||||
* @returns Resources belonging to the specified server
|
||||
*/
|
||||
export function filterResourcesByServer(
|
||||
resources: ServerResource[],
|
||||
serverName: string,
|
||||
): ServerResource[] {
|
||||
return resources.filter(resource => resource.server === serverName)
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes tools belonging to a specific MCP server
|
||||
* @param tools Array of tools
|
||||
* @param serverName Name of the MCP server to exclude
|
||||
* @returns Tools not belonging to the specified server
|
||||
*/
|
||||
export function excludeToolsByServer(
|
||||
tools: Tool[],
|
||||
serverName: string,
|
||||
): Tool[] {
|
||||
const prefix = `mcp__${normalizeNameForMCP(serverName)}__`
|
||||
return tools.filter(tool => !tool.name?.startsWith(prefix))
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes commands belonging to a specific MCP server
|
||||
* @param commands Array of commands
|
||||
* @param serverName Name of the MCP server to exclude
|
||||
* @returns Commands not belonging to the specified server
|
||||
*/
|
||||
export function excludeCommandsByServer(
|
||||
commands: Command[],
|
||||
serverName: string,
|
||||
): Command[] {
|
||||
return commands.filter(c => !commandBelongsToServer(c, serverName))
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes resources belonging to a specific MCP server
|
||||
* @param resources Map of server resources
|
||||
* @param serverName Name of the MCP server to exclude
|
||||
* @returns Resources map without the specified server
|
||||
*/
|
||||
export function excludeResourcesByServer(
|
||||
resources: Record<string, ServerResource[]>,
|
||||
serverName: string,
|
||||
): Record<string, ServerResource[]> {
|
||||
const result = { ...resources }
|
||||
delete result[serverName]
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Stable hash of an MCP server config for change detection on /reload-plugins.
|
||||
* Excludes `scope` (provenance, not content — moving a server from .mcp.json
|
||||
* to settings.json shouldn't reconnect it). Keys sorted so `{a:1,b:2}` and
|
||||
* `{b:2,a:1}` hash the same.
|
||||
*/
|
||||
export function hashMcpConfig(config: ScopedMcpServerConfig): string {
|
||||
const { scope: _scope, ...rest } = config
|
||||
const stable = jsonStringify(rest, (_k, v: unknown) => {
|
||||
if (v && typeof v === 'object' && !Array.isArray(v)) {
|
||||
const obj = v as Record<string, unknown>
|
||||
const sorted: Record<string, unknown> = {}
|
||||
for (const k of Object.keys(obj).sort()) sorted[k] = obj[k]
|
||||
return sorted
|
||||
}
|
||||
return v
|
||||
})
|
||||
return createHash('sha256').update(stable).digest('hex').slice(0, 16)
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove stale MCP clients and their tools/commands/resources. A client is
|
||||
* stale if:
|
||||
* - scope 'dynamic' and name no longer in configs (plugin disabled), or
|
||||
* - config hash changed (args/url/env edited in .mcp.json) — any scope
|
||||
*
|
||||
* The removal case is scoped to 'dynamic' so /reload-plugins can't
|
||||
* accidentally disconnect a user-configured server that's just temporarily
|
||||
* absent from the in-memory config (e.g. during a partial reload). The
|
||||
* config-changed case applies to all scopes — if the config actually changed
|
||||
* on disk, reconnecting is what you want.
|
||||
*
|
||||
* Returns the stale clients so the caller can disconnect them (clearServerCache).
|
||||
*/
|
||||
export function excludeStalePluginClients(
|
||||
mcp: {
|
||||
clients: MCPServerConnection[]
|
||||
tools: Tool[]
|
||||
commands: Command[]
|
||||
resources: Record<string, ServerResource[]>
|
||||
},
|
||||
configs: Record<string, ScopedMcpServerConfig>,
|
||||
): {
|
||||
clients: MCPServerConnection[]
|
||||
tools: Tool[]
|
||||
commands: Command[]
|
||||
resources: Record<string, ServerResource[]>
|
||||
stale: MCPServerConnection[]
|
||||
} {
|
||||
const stale = mcp.clients.filter(c => {
|
||||
const fresh = configs[c.name]
|
||||
if (!fresh) return c.config.scope === 'dynamic'
|
||||
return hashMcpConfig(c.config) !== hashMcpConfig(fresh)
|
||||
})
|
||||
if (stale.length === 0) {
|
||||
return { ...mcp, stale: [] }
|
||||
}
|
||||
|
||||
let { tools, commands, resources } = mcp
|
||||
for (const s of stale) {
|
||||
tools = excludeToolsByServer(tools, s.name)
|
||||
commands = excludeCommandsByServer(commands, s.name)
|
||||
resources = excludeResourcesByServer(resources, s.name)
|
||||
}
|
||||
const staleNames = new Set(stale.map(c => c.name))
|
||||
|
||||
return {
|
||||
clients: mcp.clients.filter(c => !staleNames.has(c.name)),
|
||||
tools,
|
||||
commands,
|
||||
resources,
|
||||
stale,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a tool name belongs to a specific MCP server
|
||||
* @param toolName The tool name to check
|
||||
* @param serverName The server name to match against
|
||||
* @returns True if the tool belongs to the specified server
|
||||
*/
|
||||
export function isToolFromMcpServer(
|
||||
toolName: string,
|
||||
serverName: string,
|
||||
): boolean {
|
||||
const info = mcpInfoFromString(toolName)
|
||||
return info?.serverName === serverName
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a tool belongs to any MCP server
|
||||
* @param tool The tool to check
|
||||
* @returns True if the tool is from an MCP server
|
||||
*/
|
||||
export function isMcpTool(tool: Tool): boolean {
|
||||
return tool.name?.startsWith('mcp__') || tool.isMcp === true
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a command belongs to any MCP server
|
||||
* @param command The command to check
|
||||
* @returns True if the command is from an MCP server
|
||||
*/
|
||||
export function isMcpCommand(command: Command): boolean {
|
||||
return command.name?.startsWith('mcp__') || command.isMcp === true
|
||||
}
|
||||
|
||||
/**
|
||||
* Describe the file path for a given MCP config scope.
|
||||
* @param scope The config scope ('user', 'project', 'local', or 'dynamic')
|
||||
* @returns A description of where the config is stored
|
||||
*/
|
||||
export function describeMcpConfigFilePath(scope: ConfigScope): string {
|
||||
switch (scope) {
|
||||
case 'user':
|
||||
return getGlobalClaudeFile()
|
||||
case 'project':
|
||||
return join(getCwd(), '.mcp.json')
|
||||
case 'local':
|
||||
return `${getGlobalClaudeFile()} [project: ${getCwd()}]`
|
||||
case 'dynamic':
|
||||
return 'Dynamically configured'
|
||||
case 'enterprise':
|
||||
return getEnterpriseMcpFilePath()
|
||||
case 'claudeai':
|
||||
return 'claude.ai'
|
||||
default:
|
||||
return scope
|
||||
}
|
||||
}
|
||||
|
||||
export function getScopeLabel(scope: ConfigScope): string {
|
||||
switch (scope) {
|
||||
case 'local':
|
||||
return 'Local config (private to you in this project)'
|
||||
case 'project':
|
||||
return 'Project config (shared via .mcp.json)'
|
||||
case 'user':
|
||||
return 'User config (available in all your projects)'
|
||||
case 'dynamic':
|
||||
return 'Dynamic config (from command line)'
|
||||
case 'enterprise':
|
||||
return 'Enterprise config (managed by your organization)'
|
||||
case 'claudeai':
|
||||
return 'claude.ai config'
|
||||
default:
|
||||
return scope
|
||||
}
|
||||
}
|
||||
|
||||
export function ensureConfigScope(scope?: string): ConfigScope {
|
||||
if (!scope) return 'local'
|
||||
|
||||
if (!ConfigScopeSchema().options.includes(scope as ConfigScope)) {
|
||||
throw new Error(
|
||||
`Invalid scope: ${scope}. Must be one of: ${ConfigScopeSchema().options.join(', ')}`,
|
||||
)
|
||||
}
|
||||
|
||||
return scope as ConfigScope
|
||||
}
|
||||
|
||||
export function ensureTransport(type?: string): 'stdio' | 'sse' | 'http' {
|
||||
if (!type) return 'stdio'
|
||||
|
||||
if (type !== 'stdio' && type !== 'sse' && type !== 'http') {
|
||||
throw new Error(
|
||||
`Invalid transport type: ${type}. Must be one of: stdio, sse, http`,
|
||||
)
|
||||
}
|
||||
|
||||
return type as 'stdio' | 'sse' | 'http'
|
||||
}
|
||||
|
||||
export function parseHeaders(headerArray: string[]): Record<string, string> {
|
||||
const headers: Record<string, string> = {}
|
||||
|
||||
for (const header of headerArray) {
|
||||
const colonIndex = header.indexOf(':')
|
||||
if (colonIndex === -1) {
|
||||
throw new Error(
|
||||
`Invalid header format: "${header}". Expected format: "Header-Name: value"`,
|
||||
)
|
||||
}
|
||||
|
||||
const key = header.substring(0, colonIndex).trim()
|
||||
const value = header.substring(colonIndex + 1).trim()
|
||||
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
`Invalid header: "${header}". Header name cannot be empty.`,
|
||||
)
|
||||
}
|
||||
|
||||
headers[key] = value
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
export function getProjectMcpServerStatus(
|
||||
serverName: string,
|
||||
): 'approved' | 'rejected' | 'pending' {
|
||||
const settings = getSettings_DEPRECATED()
|
||||
const normalizedName = normalizeNameForMCP(serverName)
|
||||
|
||||
// TODO: This fails an e2e test if the ?. is not present. This is likely a bug in the e2e test.
|
||||
// Will fix this in a follow-up PR.
|
||||
if (
|
||||
settings?.disabledMcpjsonServers?.some(
|
||||
name => normalizeNameForMCP(name) === normalizedName,
|
||||
)
|
||||
) {
|
||||
return 'rejected'
|
||||
}
|
||||
|
||||
if (
|
||||
settings?.enabledMcpjsonServers?.some(
|
||||
name => normalizeNameForMCP(name) === normalizedName,
|
||||
) ||
|
||||
settings?.enableAllProjectMcpServers
|
||||
) {
|
||||
return 'approved'
|
||||
}
|
||||
|
||||
// In bypass permissions mode (--dangerously-skip-permissions), there's no way
|
||||
// to show an approval popup. Auto-approve if projectSettings is enabled since
|
||||
// the user has explicitly chosen to bypass all permission checks.
|
||||
// SECURITY: We intentionally only check skipDangerousModePermissionPrompt via
|
||||
// hasSkipDangerousModePermissionPrompt(), which reads from userSettings/localSettings/
|
||||
// flagSettings/policySettings but NOT projectSettings (repo-level .claude/settings.json).
|
||||
// This is intentional: a repo should not be able to accept the bypass dialog on behalf of
|
||||
// users. We also do NOT check getSessionBypassPermissionsMode() here because
|
||||
// sessionBypassPermissionsMode can be set from project settings before the dialog is shown,
|
||||
// which would allow RCE attacks via malicious project settings.
|
||||
if (
|
||||
hasSkipDangerousModePermissionPrompt() &&
|
||||
isSettingSourceEnabled('projectSettings')
|
||||
) {
|
||||
return 'approved'
|
||||
}
|
||||
|
||||
// In non-interactive mode (SDK, claude -p, piped input), there's no way to
|
||||
// show an approval popup. Auto-approve if projectSettings is enabled since:
|
||||
// 1. The user/developer explicitly chose to run in this mode
|
||||
// 2. For SDK, projectSettings is off by default - they must explicitly enable it
|
||||
// 3. For -p mode, the help text warns to only use in trusted directories
|
||||
if (
|
||||
getIsNonInteractiveSession() &&
|
||||
isSettingSourceEnabled('projectSettings')
|
||||
) {
|
||||
return 'approved'
|
||||
}
|
||||
|
||||
return 'pending'
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the scope/settings source for an MCP server from a tool name
|
||||
* @param toolName MCP tool name (format: mcp__serverName__toolName)
|
||||
* @returns ConfigScope or null if not an MCP tool or server not found
|
||||
*/
|
||||
export function getMcpServerScopeFromToolName(
|
||||
toolName: string,
|
||||
): ConfigScope | null {
|
||||
if (!isMcpTool({ name: toolName } as Tool)) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Extract server name from tool name (format: mcp__serverName__toolName)
|
||||
const mcpInfo = mcpInfoFromString(toolName)
|
||||
if (!mcpInfo) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Look up server config
|
||||
const serverConfig = getMcpConfigByName(mcpInfo.serverName)
|
||||
|
||||
// Fallback: claude.ai servers have normalized names starting with "claude_ai_"
|
||||
// but aren't in getMcpConfigByName (they're fetched async separately)
|
||||
if (!serverConfig && mcpInfo.serverName.startsWith('claude_ai_')) {
|
||||
return 'claudeai'
|
||||
}
|
||||
|
||||
return serverConfig?.scope ?? null
|
||||
}
|
||||
|
||||
// Type guards for MCP server config types
|
||||
function isStdioConfig(
|
||||
config: McpServerConfig,
|
||||
): config is McpStdioServerConfig {
|
||||
return config.type === 'stdio' || config.type === undefined
|
||||
}
|
||||
|
||||
function isSSEConfig(config: McpServerConfig): config is McpSSEServerConfig {
|
||||
return config.type === 'sse'
|
||||
}
|
||||
|
||||
function isHTTPConfig(config: McpServerConfig): config is McpHTTPServerConfig {
|
||||
return config.type === 'http'
|
||||
}
|
||||
|
||||
function isWebSocketConfig(
|
||||
config: McpServerConfig,
|
||||
): config is McpWebSocketServerConfig {
|
||||
return config.type === 'ws'
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts MCP server definitions from agent frontmatter and groups them by server name.
|
||||
* This is used to show agent-specific MCP servers in the /mcp command.
|
||||
*
|
||||
* @param agents Array of agent definitions
|
||||
* @returns Array of AgentMcpServerInfo, grouped by server name with list of source agents
|
||||
*/
|
||||
export function extractAgentMcpServers(
|
||||
agents: AgentDefinition[],
|
||||
): AgentMcpServerInfo[] {
|
||||
// Map: server name -> { config, sourceAgents }
|
||||
const serverMap = new Map<
|
||||
string,
|
||||
{
|
||||
config: McpServerConfig & { name: string }
|
||||
sourceAgents: string[]
|
||||
}
|
||||
>()
|
||||
|
||||
for (const agent of agents) {
|
||||
if (!agent.mcpServers?.length) continue
|
||||
|
||||
for (const spec of agent.mcpServers) {
|
||||
// Skip string references - these refer to servers already in global config
|
||||
if (typeof spec === 'string') continue
|
||||
|
||||
// Inline definition as { [name]: config }
|
||||
const entries = Object.entries(spec)
|
||||
if (entries.length !== 1) continue
|
||||
|
||||
const [serverName, serverConfig] = entries[0]!
|
||||
const existing = serverMap.get(serverName)
|
||||
|
||||
if (existing) {
|
||||
// Add this agent as another source
|
||||
if (!existing.sourceAgents.includes(agent.agentType)) {
|
||||
existing.sourceAgents.push(agent.agentType)
|
||||
}
|
||||
} else {
|
||||
// New server
|
||||
serverMap.set(serverName, {
|
||||
config: { ...serverConfig, name: serverName } as McpServerConfig & {
|
||||
name: string
|
||||
},
|
||||
sourceAgents: [agent.agentType],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map to array of AgentMcpServerInfo
|
||||
// Only include transport types supported by AgentMcpServerInfo
|
||||
const result: AgentMcpServerInfo[] = []
|
||||
for (const [name, { config, sourceAgents }] of serverMap) {
|
||||
// Use type guards to properly narrow the discriminated union type
|
||||
// Only include transport types that are supported by AgentMcpServerInfo
|
||||
if (isStdioConfig(config)) {
|
||||
result.push({
|
||||
name,
|
||||
sourceAgents,
|
||||
transport: 'stdio',
|
||||
command: config.command,
|
||||
needsAuth: false,
|
||||
})
|
||||
} else if (isSSEConfig(config)) {
|
||||
result.push({
|
||||
name,
|
||||
sourceAgents,
|
||||
transport: 'sse',
|
||||
url: config.url,
|
||||
needsAuth: true,
|
||||
})
|
||||
} else if (isHTTPConfig(config)) {
|
||||
result.push({
|
||||
name,
|
||||
sourceAgents,
|
||||
transport: 'http',
|
||||
url: config.url,
|
||||
needsAuth: true,
|
||||
})
|
||||
} else if (isWebSocketConfig(config)) {
|
||||
result.push({
|
||||
name,
|
||||
sourceAgents,
|
||||
transport: 'ws',
|
||||
url: config.url,
|
||||
needsAuth: false,
|
||||
})
|
||||
}
|
||||
// Skip unsupported transport types (sdk, claudeai-proxy, sse-ide, ws-ide)
|
||||
// These are internal types not meant for agent MCP server display
|
||||
}
|
||||
|
||||
return result.sort((a, b) => a.name.localeCompare(b.name))
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the MCP server base URL (without query string) for analytics logging.
|
||||
* Query strings are stripped because they can contain access tokens.
|
||||
* Trailing slashes are also removed for normalization.
|
||||
* Returns undefined for stdio/sdk servers or if URL parsing fails.
|
||||
*/
|
||||
export function getLoggingSafeMcpBaseUrl(
|
||||
config: McpServerConfig,
|
||||
): string | undefined {
|
||||
if (!('url' in config) || typeof config.url !== 'string') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(config.url)
|
||||
url.search = ''
|
||||
return url.toString().replace(/\/$/, '')
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
112
src/services/mcp/vscodeSdkMcp.ts
Normal file
112
src/services/mcp/vscodeSdkMcp.ts
Normal file
@@ -0,0 +1,112 @@
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
import { z } from 'zod/v4'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
import {
|
||||
checkStatsigFeatureGate_CACHED_MAY_BE_STALE,
|
||||
getFeatureValue_CACHED_MAY_BE_STALE,
|
||||
} from '../analytics/growthbook.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
import type { ConnectedMCPServer, MCPServerConnection } from './types.js'
|
||||
|
||||
// Mirror of AutoModeEnabledState in permissionSetup.ts — inlined because that
|
||||
// file pulls in too many deps for this thin IPC module.
|
||||
type AutoModeEnabledState = 'enabled' | 'disabled' | 'opt-in'
|
||||
function readAutoModeEnabledState(): AutoModeEnabledState | undefined {
|
||||
const v = getFeatureValue_CACHED_MAY_BE_STALE<{ enabled?: string }>(
|
||||
'tengu_auto_mode_config',
|
||||
{},
|
||||
)?.enabled
|
||||
return v === 'enabled' || v === 'disabled' || v === 'opt-in' ? v : undefined
|
||||
}
|
||||
|
||||
export const LogEventNotificationSchema = lazySchema(() =>
|
||||
z.object({
|
||||
method: z.literal('log_event'),
|
||||
params: z.object({
|
||||
eventName: z.string(),
|
||||
eventData: z.object({}).passthrough(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
|
||||
// Store the VSCode MCP client reference for sending notifications
|
||||
let vscodeMcpClient: ConnectedMCPServer | null = null
|
||||
|
||||
/**
|
||||
* Sends a file_updated notification to the VSCode MCP server. This is used to
|
||||
* notify VSCode when files are edited or written by Claude.
|
||||
*/
|
||||
export function notifyVscodeFileUpdated(
|
||||
filePath: string,
|
||||
oldContent: string | null,
|
||||
newContent: string | null,
|
||||
): void {
|
||||
if (process.env.USER_TYPE !== 'ant' || !vscodeMcpClient) {
|
||||
return
|
||||
}
|
||||
|
||||
void vscodeMcpClient.client
|
||||
.notification({
|
||||
method: 'file_updated',
|
||||
params: { filePath, oldContent, newContent },
|
||||
})
|
||||
.catch((error: Error) => {
|
||||
// Do not throw if the notification failed
|
||||
logForDebugging(
|
||||
`[VSCode] Failed to send file_updated notification: ${error.message}`,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up the speicial internal VSCode MCP for bidirectional communication using notifications.
|
||||
*/
|
||||
export function setupVscodeSdkMcp(sdkClients: MCPServerConnection[]): void {
|
||||
const client = sdkClients.find(client => client.name === 'claude-vscode')
|
||||
|
||||
if (client && client.type === 'connected') {
|
||||
// Store the client reference for later use
|
||||
vscodeMcpClient = client
|
||||
|
||||
client.client.setNotificationHandler(
|
||||
LogEventNotificationSchema(),
|
||||
async notification => {
|
||||
const { eventName, eventData } = notification.params
|
||||
logEvent(
|
||||
`tengu_vscode_${eventName}`,
|
||||
eventData as { [key: string]: boolean | number | undefined },
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
// Send necessary experiment gates to VSCode immediately.
|
||||
const gates: Record<string, boolean | string> = {
|
||||
tengu_vscode_review_upsell: checkStatsigFeatureGate_CACHED_MAY_BE_STALE(
|
||||
'tengu_vscode_review_upsell',
|
||||
),
|
||||
tengu_vscode_onboarding: checkStatsigFeatureGate_CACHED_MAY_BE_STALE(
|
||||
'tengu_vscode_onboarding',
|
||||
),
|
||||
// Browser support.
|
||||
tengu_quiet_fern: getFeatureValue_CACHED_MAY_BE_STALE(
|
||||
'tengu_quiet_fern',
|
||||
false,
|
||||
),
|
||||
// In-band OAuth via claude_authenticate (vs. extension-native PKCE).
|
||||
tengu_vscode_cc_auth: getFeatureValue_CACHED_MAY_BE_STALE(
|
||||
'tengu_vscode_cc_auth',
|
||||
false,
|
||||
),
|
||||
}
|
||||
// Tri-state: 'enabled' | 'disabled' | 'opt-in'. Omit if unknown so VSCode
|
||||
// fails closed (treats absent as 'disabled').
|
||||
const autoModeState = readAutoModeEnabledState()
|
||||
if (autoModeState !== undefined) {
|
||||
gates.tengu_auto_mode_state = autoModeState
|
||||
}
|
||||
void client.client.notification({
|
||||
method: 'experiment_gates',
|
||||
params: { gates },
|
||||
})
|
||||
}
|
||||
}
|
||||
511
src/services/mcp/xaa.ts
Normal file
511
src/services/mcp/xaa.ts
Normal file
@@ -0,0 +1,511 @@
|
||||
/**
|
||||
* Cross-App Access (XAA) / Enterprise Managed Authorization (SEP-990)
|
||||
*
|
||||
* Obtains an MCP access token WITHOUT a browser consent screen by chaining:
|
||||
* 1. RFC 8693 Token Exchange at the IdP: id_token → ID-JAG
|
||||
* 2. RFC 7523 JWT Bearer Grant at the AS: ID-JAG → access_token
|
||||
*
|
||||
* Spec refs:
|
||||
* - ID-JAG (IETF draft): https://datatracker.ietf.org/doc/draft-ietf-oauth-identity-assertion-authz-grant/
|
||||
* - MCP ext-auth (SEP-990): https://github.com/modelcontextprotocol/ext-auth
|
||||
* - RFC 8693 (Token Exchange), RFC 7523 (JWT Bearer), RFC 9728 (PRM)
|
||||
*
|
||||
* Reference impl: ~/code/mcp/conformance/examples/clients/typescript/everything-client.ts:375-522
|
||||
*
|
||||
* Structure: four Layer-2 ops (aligned with TS SDK PR #1593's Layer-2 shapes so
|
||||
* a future SDK swap is mechanical) + one Layer-3 orchestrator that composes them.
|
||||
*/
|
||||
|
||||
import {
|
||||
discoverAuthorizationServerMetadata,
|
||||
discoverOAuthProtectedResourceMetadata,
|
||||
} from '@modelcontextprotocol/sdk/client/auth.js'
|
||||
import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||
import { z } from 'zod/v4'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
import { logMCPDebug } from '../../utils/log.js'
|
||||
import { jsonStringify } from '../../utils/slowOperations.js'
|
||||
|
||||
const XAA_REQUEST_TIMEOUT_MS = 30000
|
||||
|
||||
const TOKEN_EXCHANGE_GRANT = 'urn:ietf:params:oauth:grant-type:token-exchange'
|
||||
const JWT_BEARER_GRANT = 'urn:ietf:params:oauth:grant-type:jwt-bearer'
|
||||
const ID_JAG_TOKEN_TYPE = 'urn:ietf:params:oauth:token-type:id-jag'
|
||||
const ID_TOKEN_TYPE = 'urn:ietf:params:oauth:token-type:id_token'
|
||||
|
||||
/**
|
||||
* Creates a fetch wrapper that enforces the XAA request timeout and optionally
|
||||
* composes a caller-provided abort signal. Using AbortSignal.any ensures the
|
||||
* user's cancel (e.g. Esc in the auth menu) actually aborts in-flight requests
|
||||
* rather than being clobbered by the timeout signal.
|
||||
*/
|
||||
function makeXaaFetch(abortSignal?: AbortSignal): FetchLike {
|
||||
return (url, init) => {
|
||||
const timeout = AbortSignal.timeout(XAA_REQUEST_TIMEOUT_MS)
|
||||
const signal = abortSignal
|
||||
? // eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
AbortSignal.any([timeout, abortSignal])
|
||||
: timeout
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
return fetch(url, { ...init, signal })
|
||||
}
|
||||
}
|
||||
|
||||
const defaultFetch = makeXaaFetch()
|
||||
|
||||
/**
|
||||
* RFC 8414 §3.3 / RFC 9728 §3.3 identifier comparison. Roundtrip through URL
|
||||
* to apply RFC 3986 §6.2.2 syntax-based normalization (lowercases scheme+host,
|
||||
* drops default port), then strip trailing slash.
|
||||
*/
|
||||
function normalizeUrl(url: string): string {
|
||||
try {
|
||||
return new URL(url).href.replace(/\/$/, '')
|
||||
} catch {
|
||||
return url.replace(/\/$/, '')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Thrown by requestJwtAuthorizationGrant when the IdP token-exchange leg
|
||||
* fails. Carries `shouldClearIdToken` so callers can decide whether to drop
|
||||
* the cached id_token based on OAuth error semantics (not substring matching):
|
||||
* - 4xx / invalid_grant / invalid_token → id_token is bad, clear it
|
||||
* - 5xx → IdP is down, id_token may still be valid, keep it
|
||||
* - 200 with structurally-invalid body → protocol violation, clear it
|
||||
*/
|
||||
export class XaaTokenExchangeError extends Error {
|
||||
readonly shouldClearIdToken: boolean
|
||||
constructor(message: string, shouldClearIdToken: boolean) {
|
||||
super(message)
|
||||
this.name = 'XaaTokenExchangeError'
|
||||
this.shouldClearIdToken = shouldClearIdToken
|
||||
}
|
||||
}
|
||||
|
||||
// Matches quoted values for known token-bearing keys regardless of nesting
|
||||
// depth. Works on both parsed-then-stringified bodies AND raw text() error
|
||||
// bodies from !res.ok paths — a misbehaving AS that echoes the request's
|
||||
// subject_token/assertion/client_secret in a 4xx error envelope must not leak
|
||||
// into debug logs.
|
||||
const SENSITIVE_TOKEN_RE =
|
||||
/"(access_token|refresh_token|id_token|assertion|subject_token|client_secret)"\s*:\s*"[^"]*"/g
|
||||
|
||||
function redactTokens(raw: unknown): string {
|
||||
const s = typeof raw === 'string' ? raw : jsonStringify(raw)
|
||||
return s.replace(SENSITIVE_TOKEN_RE, (_, k) => `"${k}":"[REDACTED]"`)
|
||||
}
|
||||
|
||||
// ─── Zod Schemas ────────────────────────────────────────────────────────────
|
||||
|
||||
const TokenExchangeResponseSchema = lazySchema(() =>
|
||||
z.object({
|
||||
access_token: z.string().optional(),
|
||||
issued_token_type: z.string().optional(),
|
||||
// z.coerce tolerates IdPs that send expires_in as a string (common in
|
||||
// PHP-backed IdPs) — technically non-conformant JSON but widespread.
|
||||
expires_in: z.coerce.number().optional(),
|
||||
scope: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
const JwtBearerResponseSchema = lazySchema(() =>
|
||||
z.object({
|
||||
access_token: z.string().min(1),
|
||||
// Many ASes omit token_type since Bearer is the only value anyone uses
|
||||
// (RFC 6750). Don't reject a valid access_token over a missing label.
|
||||
token_type: z.string().default('Bearer'),
|
||||
expires_in: z.coerce.number().optional(),
|
||||
scope: z.string().optional(),
|
||||
refresh_token: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
// ─── Layer 2: Discovery ─────────────────────────────────────────────────────
|
||||
|
||||
export type ProtectedResourceMetadata = {
|
||||
resource: string
|
||||
authorization_servers: string[]
|
||||
}
|
||||
|
||||
/**
|
||||
* RFC 9728 PRM discovery via SDK, plus RFC 9728 §3.3 resource-mismatch
|
||||
* validation (mix-up protection — TODO: upstream to SDK).
|
||||
*/
|
||||
export async function discoverProtectedResource(
|
||||
serverUrl: string,
|
||||
opts?: { fetchFn?: FetchLike },
|
||||
): Promise<ProtectedResourceMetadata> {
|
||||
let prm
|
||||
try {
|
||||
prm = await discoverOAuthProtectedResourceMetadata(
|
||||
serverUrl,
|
||||
undefined,
|
||||
opts?.fetchFn ?? defaultFetch,
|
||||
)
|
||||
} catch (e) {
|
||||
throw new Error(
|
||||
`XAA: PRM discovery failed: ${e instanceof Error ? e.message : String(e)}`,
|
||||
)
|
||||
}
|
||||
if (!prm.resource || !prm.authorization_servers?.[0]) {
|
||||
throw new Error(
|
||||
'XAA: PRM discovery failed: PRM missing resource or authorization_servers',
|
||||
)
|
||||
}
|
||||
if (normalizeUrl(prm.resource) !== normalizeUrl(serverUrl)) {
|
||||
throw new Error(
|
||||
`XAA: PRM discovery failed: PRM resource mismatch: expected ${serverUrl}, got ${prm.resource}`,
|
||||
)
|
||||
}
|
||||
return {
|
||||
resource: prm.resource,
|
||||
authorization_servers: prm.authorization_servers,
|
||||
}
|
||||
}
|
||||
|
||||
export type AuthorizationServerMetadata = {
|
||||
issuer: string
|
||||
token_endpoint: string
|
||||
grant_types_supported?: string[]
|
||||
token_endpoint_auth_methods_supported?: string[]
|
||||
}
|
||||
|
||||
/**
|
||||
* AS metadata discovery via SDK (RFC 8414 + OIDC fallback), plus RFC 8414
|
||||
* §3.3 issuer-mismatch validation (mix-up protection — TODO: upstream to SDK).
|
||||
*/
|
||||
export async function discoverAuthorizationServer(
|
||||
asUrl: string,
|
||||
opts?: { fetchFn?: FetchLike },
|
||||
): Promise<AuthorizationServerMetadata> {
|
||||
const meta = await discoverAuthorizationServerMetadata(asUrl, {
|
||||
fetchFn: opts?.fetchFn ?? defaultFetch,
|
||||
})
|
||||
if (!meta?.issuer || !meta.token_endpoint) {
|
||||
throw new Error(
|
||||
`XAA: AS metadata discovery failed: no valid metadata at ${asUrl}`,
|
||||
)
|
||||
}
|
||||
if (normalizeUrl(meta.issuer) !== normalizeUrl(asUrl)) {
|
||||
throw new Error(
|
||||
`XAA: AS metadata discovery failed: issuer mismatch: expected ${asUrl}, got ${meta.issuer}`,
|
||||
)
|
||||
}
|
||||
// RFC 8414 §3.3 / RFC 9728 §3 require HTTPS. A PRM-advertised http:// AS
|
||||
// that self-consistently reports an http:// issuer would pass the mismatch
|
||||
// check above, then we'd POST id_token + client_secret over plaintext.
|
||||
if (new URL(meta.token_endpoint).protocol !== 'https:') {
|
||||
throw new Error(
|
||||
`XAA: refusing non-HTTPS token endpoint: ${meta.token_endpoint}`,
|
||||
)
|
||||
}
|
||||
return {
|
||||
issuer: meta.issuer,
|
||||
token_endpoint: meta.token_endpoint,
|
||||
grant_types_supported: meta.grant_types_supported,
|
||||
token_endpoint_auth_methods_supported:
|
||||
meta.token_endpoint_auth_methods_supported,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Layer 2: Exchange ──────────────────────────────────────────────────────
|
||||
|
||||
export type JwtAuthGrantResult = {
|
||||
/** The ID-JAG (Identity Assertion Authorization Grant) */
|
||||
jwtAuthGrant: string
|
||||
expiresIn?: number
|
||||
scope?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* RFC 8693 Token Exchange at the IdP: id_token → ID-JAG.
|
||||
* Validates `issued_token_type` is `urn:ietf:params:oauth:token-type:id-jag`.
|
||||
*
|
||||
* `clientSecret` is optional — sent via `client_secret_post` if present.
|
||||
* Some IdPs register the client as confidential even when they advertise
|
||||
* `token_endpoint_auth_method: "none"`.
|
||||
*
|
||||
* TODO(xaa-ga): consult `token_endpoint_auth_methods_supported` from IdP
|
||||
* OIDC metadata and support `client_secret_basic`, mirroring the AS-side
|
||||
* selection in `performCrossAppAccess`. All major IdPs accept POST today.
|
||||
*/
|
||||
export async function requestJwtAuthorizationGrant(opts: {
|
||||
tokenEndpoint: string
|
||||
audience: string
|
||||
resource: string
|
||||
idToken: string
|
||||
clientId: string
|
||||
clientSecret?: string
|
||||
scope?: string
|
||||
fetchFn?: FetchLike
|
||||
}): Promise<JwtAuthGrantResult> {
|
||||
const fetchFn = opts.fetchFn ?? defaultFetch
|
||||
const params = new URLSearchParams({
|
||||
grant_type: TOKEN_EXCHANGE_GRANT,
|
||||
requested_token_type: ID_JAG_TOKEN_TYPE,
|
||||
audience: opts.audience,
|
||||
resource: opts.resource,
|
||||
subject_token: opts.idToken,
|
||||
subject_token_type: ID_TOKEN_TYPE,
|
||||
client_id: opts.clientId,
|
||||
})
|
||||
if (opts.clientSecret) {
|
||||
params.set('client_secret', opts.clientSecret)
|
||||
}
|
||||
if (opts.scope) {
|
||||
params.set('scope', opts.scope)
|
||||
}
|
||||
|
||||
const res = await fetchFn(opts.tokenEndpoint, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: params,
|
||||
})
|
||||
if (!res.ok) {
|
||||
const body = redactTokens(await res.text()).slice(0, 200)
|
||||
// 4xx → id_token rejected (invalid_grant etc.), clear cache.
|
||||
// 5xx → IdP outage, id_token may still be valid, preserve it.
|
||||
const shouldClear = res.status < 500
|
||||
throw new XaaTokenExchangeError(
|
||||
`XAA: token exchange failed: HTTP ${res.status}: ${body}`,
|
||||
shouldClear,
|
||||
)
|
||||
}
|
||||
let rawExchange: unknown
|
||||
try {
|
||||
rawExchange = await res.json()
|
||||
} catch {
|
||||
// Transient network condition (captive portal, proxy) — don't clear id_token.
|
||||
throw new XaaTokenExchangeError(
|
||||
`XAA: token exchange returned non-JSON (captive portal?) at ${opts.tokenEndpoint}`,
|
||||
false,
|
||||
)
|
||||
}
|
||||
const exchangeParsed = TokenExchangeResponseSchema().safeParse(rawExchange)
|
||||
if (!exchangeParsed.success) {
|
||||
throw new XaaTokenExchangeError(
|
||||
`XAA: token exchange response did not match expected shape: ${redactTokens(rawExchange)}`,
|
||||
true,
|
||||
)
|
||||
}
|
||||
const result = exchangeParsed.data
|
||||
if (!result.access_token) {
|
||||
throw new XaaTokenExchangeError(
|
||||
`XAA: token exchange response missing access_token: ${redactTokens(result)}`,
|
||||
true,
|
||||
)
|
||||
}
|
||||
if (result.issued_token_type !== ID_JAG_TOKEN_TYPE) {
|
||||
throw new XaaTokenExchangeError(
|
||||
`XAA: token exchange returned unexpected issued_token_type: ${result.issued_token_type}`,
|
||||
true,
|
||||
)
|
||||
}
|
||||
return {
|
||||
jwtAuthGrant: result.access_token,
|
||||
expiresIn: result.expires_in,
|
||||
scope: result.scope,
|
||||
}
|
||||
}
|
||||
|
||||
export type XaaTokenResult = {
|
||||
access_token: string
|
||||
token_type: string
|
||||
expires_in?: number
|
||||
scope?: string
|
||||
refresh_token?: string
|
||||
}
|
||||
|
||||
export type XaaResult = XaaTokenResult & {
|
||||
/**
|
||||
* The AS issuer URL discovered via PRM. Callers must persist this as
|
||||
* `discoveryState.authorizationServerUrl` so that refresh (auth.ts _doRefresh)
|
||||
* and revocation (revokeServerTokens) can locate the token/revocation
|
||||
* endpoints — the MCP URL is not the AS URL in typical XAA setups.
|
||||
*/
|
||||
authorizationServerUrl: string
|
||||
}
|
||||
|
||||
/**
|
||||
* RFC 7523 JWT Bearer Grant at the AS: ID-JAG → access_token.
|
||||
*
|
||||
* `authMethod` defaults to `client_secret_basic` (Base64 header, not body
|
||||
* params) — the SEP-990 conformance test requires this. Only set
|
||||
* `client_secret_post` if the AS explicitly requires it.
|
||||
*/
|
||||
export async function exchangeJwtAuthGrant(opts: {
|
||||
tokenEndpoint: string
|
||||
assertion: string
|
||||
clientId: string
|
||||
clientSecret: string
|
||||
authMethod?: 'client_secret_basic' | 'client_secret_post'
|
||||
scope?: string
|
||||
fetchFn?: FetchLike
|
||||
}): Promise<XaaTokenResult> {
|
||||
const fetchFn = opts.fetchFn ?? defaultFetch
|
||||
const authMethod = opts.authMethod ?? 'client_secret_basic'
|
||||
|
||||
const params = new URLSearchParams({
|
||||
grant_type: JWT_BEARER_GRANT,
|
||||
assertion: opts.assertion,
|
||||
})
|
||||
if (opts.scope) {
|
||||
params.set('scope', opts.scope)
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
}
|
||||
if (authMethod === 'client_secret_basic') {
|
||||
const basicAuth = Buffer.from(
|
||||
`${encodeURIComponent(opts.clientId)}:${encodeURIComponent(opts.clientSecret)}`,
|
||||
).toString('base64')
|
||||
headers.Authorization = `Basic ${basicAuth}`
|
||||
} else {
|
||||
params.set('client_id', opts.clientId)
|
||||
params.set('client_secret', opts.clientSecret)
|
||||
}
|
||||
|
||||
const res = await fetchFn(opts.tokenEndpoint, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: params,
|
||||
})
|
||||
if (!res.ok) {
|
||||
const body = redactTokens(await res.text()).slice(0, 200)
|
||||
throw new Error(`XAA: jwt-bearer grant failed: HTTP ${res.status}: ${body}`)
|
||||
}
|
||||
let rawTokens: unknown
|
||||
try {
|
||||
rawTokens = await res.json()
|
||||
} catch {
|
||||
throw new Error(
|
||||
`XAA: jwt-bearer grant returned non-JSON (captive portal?) at ${opts.tokenEndpoint}`,
|
||||
)
|
||||
}
|
||||
const tokensParsed = JwtBearerResponseSchema().safeParse(rawTokens)
|
||||
if (!tokensParsed.success) {
|
||||
throw new Error(
|
||||
`XAA: jwt-bearer response did not match expected shape: ${redactTokens(rawTokens)}`,
|
||||
)
|
||||
}
|
||||
return tokensParsed.data
|
||||
}
|
||||
|
||||
// ─── Layer 3: Orchestrator ──────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Config needed to run the full XAA orchestrator.
|
||||
* Mirrors the conformance test context shape (see ClientConformanceContextSchema).
|
||||
*/
|
||||
export type XaaConfig = {
|
||||
/** Client ID registered at the MCP server's authorization server */
|
||||
clientId: string
|
||||
/** Client secret for the MCP server's authorization server */
|
||||
clientSecret: string
|
||||
/** Client ID registered at the IdP (for the token-exchange request) */
|
||||
idpClientId: string
|
||||
/** Optional IdP client secret (client_secret_post) — some IdPs require it */
|
||||
idpClientSecret?: string
|
||||
/** The user's OIDC id_token from the IdP login */
|
||||
idpIdToken: string
|
||||
/** IdP token endpoint (where to send the RFC 8693 token-exchange) */
|
||||
idpTokenEndpoint: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Full XAA flow: PRM → AS metadata → token-exchange → jwt-bearer → access_token.
|
||||
* Thin composition of the four Layer-2 ops. Used by performMCPXaaAuth,
|
||||
* ClaudeAuthProvider.xaaRefresh, and the try-xaa*.ts debug scripts.
|
||||
*
|
||||
* @param serverUrl The MCP server URL (e.g. `https://mcp.example.com/mcp`)
|
||||
* @param config IdP + AS credentials
|
||||
* @param serverName Server name for debug logging
|
||||
*/
|
||||
export async function performCrossAppAccess(
|
||||
serverUrl: string,
|
||||
config: XaaConfig,
|
||||
serverName = 'xaa',
|
||||
abortSignal?: AbortSignal,
|
||||
): Promise<XaaResult> {
|
||||
const fetchFn = makeXaaFetch(abortSignal)
|
||||
|
||||
logMCPDebug(serverName, `XAA: discovering PRM for ${serverUrl}`)
|
||||
const prm = await discoverProtectedResource(serverUrl, { fetchFn })
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`XAA: discovered resource=${prm.resource} ASes=[${prm.authorization_servers.join(', ')}]`,
|
||||
)
|
||||
|
||||
// Try each advertised AS in order. grant_types_supported is OPTIONAL per
|
||||
// RFC 8414 §2 — only skip if the AS explicitly advertises a list that omits
|
||||
// jwt-bearer. If absent, let the token endpoint decide.
|
||||
let asMeta: AuthorizationServerMetadata | undefined
|
||||
const asErrors: string[] = []
|
||||
for (const asUrl of prm.authorization_servers) {
|
||||
let candidate: AuthorizationServerMetadata
|
||||
try {
|
||||
candidate = await discoverAuthorizationServer(asUrl, { fetchFn })
|
||||
} catch (e) {
|
||||
if (abortSignal?.aborted) throw e
|
||||
asErrors.push(`${asUrl}: ${e instanceof Error ? e.message : String(e)}`)
|
||||
continue
|
||||
}
|
||||
if (
|
||||
candidate.grant_types_supported &&
|
||||
!candidate.grant_types_supported.includes(JWT_BEARER_GRANT)
|
||||
) {
|
||||
asErrors.push(
|
||||
`${asUrl}: does not advertise jwt-bearer grant (supported: ${candidate.grant_types_supported.join(', ')})`,
|
||||
)
|
||||
continue
|
||||
}
|
||||
asMeta = candidate
|
||||
break
|
||||
}
|
||||
if (!asMeta) {
|
||||
throw new Error(
|
||||
`XAA: no authorization server supports jwt-bearer. Tried: ${asErrors.join('; ')}`,
|
||||
)
|
||||
}
|
||||
// Pick auth method from what the AS advertises. We handle
|
||||
// client_secret_basic and client_secret_post; if the AS only supports post,
|
||||
// honor that, else default to basic (SEP-990 conformance expectation).
|
||||
const authMethods = asMeta.token_endpoint_auth_methods_supported
|
||||
const authMethod: 'client_secret_basic' | 'client_secret_post' =
|
||||
authMethods &&
|
||||
!authMethods.includes('client_secret_basic') &&
|
||||
authMethods.includes('client_secret_post')
|
||||
? 'client_secret_post'
|
||||
: 'client_secret_basic'
|
||||
logMCPDebug(
|
||||
serverName,
|
||||
`XAA: AS issuer=${asMeta.issuer} token_endpoint=${asMeta.token_endpoint} auth_method=${authMethod}`,
|
||||
)
|
||||
|
||||
logMCPDebug(serverName, `XAA: exchanging id_token for ID-JAG at IdP`)
|
||||
const jag = await requestJwtAuthorizationGrant({
|
||||
tokenEndpoint: config.idpTokenEndpoint,
|
||||
audience: asMeta.issuer,
|
||||
resource: prm.resource,
|
||||
idToken: config.idpIdToken,
|
||||
clientId: config.idpClientId,
|
||||
clientSecret: config.idpClientSecret,
|
||||
fetchFn,
|
||||
})
|
||||
logMCPDebug(serverName, `XAA: ID-JAG obtained`)
|
||||
|
||||
logMCPDebug(serverName, `XAA: exchanging ID-JAG for access_token at AS`)
|
||||
const tokens = await exchangeJwtAuthGrant({
|
||||
tokenEndpoint: asMeta.token_endpoint,
|
||||
assertion: jag.jwtAuthGrant,
|
||||
clientId: config.clientId,
|
||||
clientSecret: config.clientSecret,
|
||||
authMethod,
|
||||
fetchFn,
|
||||
})
|
||||
logMCPDebug(serverName, `XAA: access_token obtained`)
|
||||
|
||||
return { ...tokens, authorizationServerUrl: asMeta.issuer }
|
||||
}
|
||||
487
src/services/mcp/xaaIdpLogin.ts
Normal file
487
src/services/mcp/xaaIdpLogin.ts
Normal file
@@ -0,0 +1,487 @@
|
||||
/**
|
||||
* XAA IdP Login — acquires an OIDC id_token from an enterprise IdP via the
|
||||
* standard authorization_code + PKCE flow, then caches it by IdP issuer.
|
||||
*
|
||||
* This is the "one browser pop" in the XAA value prop: one IdP login → N silent
|
||||
* MCP server auths. The id_token is cached in the keychain and reused until expiry.
|
||||
*/
|
||||
|
||||
import {
|
||||
exchangeAuthorization,
|
||||
startAuthorization,
|
||||
} from '@modelcontextprotocol/sdk/client/auth.js'
|
||||
import {
|
||||
type OAuthClientInformation,
|
||||
type OpenIdProviderDiscoveryMetadata,
|
||||
OpenIdProviderDiscoveryMetadataSchema,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||
import { randomBytes } from 'crypto'
|
||||
import { createServer, type Server } from 'http'
|
||||
import { parse } from 'url'
|
||||
import xss from 'xss'
|
||||
import { openBrowser } from '../../utils/browser.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import { toError } from '../../utils/errors.js'
|
||||
import { logMCPDebug } from '../../utils/log.js'
|
||||
import { getPlatform } from '../../utils/platform.js'
|
||||
import { getSecureStorage } from '../../utils/secureStorage/index.js'
|
||||
import { getInitialSettings } from '../../utils/settings/settings.js'
|
||||
import { jsonParse } from '../../utils/slowOperations.js'
|
||||
import { buildRedirectUri, findAvailablePort } from './oauthPort.js'
|
||||
|
||||
export function isXaaEnabled(): boolean {
|
||||
return isEnvTruthy(process.env.CLAUDE_CODE_ENABLE_XAA)
|
||||
}
|
||||
|
||||
export type XaaIdpSettings = {
|
||||
issuer: string
|
||||
clientId: string
|
||||
callbackPort?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Typed accessor for settings.xaaIdp. The field is env-gated in SettingsSchema
|
||||
* so it doesn't surface in SDK types/docs — which means the inferred settings
|
||||
* type doesn't have it at compile time. This is the one cast.
|
||||
*/
|
||||
export function getXaaIdpSettings(): XaaIdpSettings | undefined {
|
||||
return (getInitialSettings() as { xaaIdp?: XaaIdpSettings }).xaaIdp
|
||||
}
|
||||
|
||||
const IDP_LOGIN_TIMEOUT_MS = 5 * 60 * 1000
|
||||
const IDP_REQUEST_TIMEOUT_MS = 30000
|
||||
const ID_TOKEN_EXPIRY_BUFFER_S = 60
|
||||
|
||||
export type IdpLoginOptions = {
|
||||
idpIssuer: string
|
||||
idpClientId: string
|
||||
/**
|
||||
* Optional IdP client secret for confidential clients. Auth method
|
||||
* (client_secret_post, client_secret_basic, none) is chosen per IdP
|
||||
* metadata. Omit for public clients (PKCE only).
|
||||
*/
|
||||
idpClientSecret?: string
|
||||
/**
|
||||
* Fixed callback port. If omitted, a random port is chosen.
|
||||
* Use this when the IdP client is pre-registered with a specific loopback
|
||||
* redirect URI (RFC 8252 §7.3 says IdPs SHOULD accept any port for
|
||||
* http://localhost, but many don't).
|
||||
*/
|
||||
callbackPort?: number
|
||||
/** Called with the authorization URL before (or instead of) opening the browser */
|
||||
onAuthorizationUrl?: (url: string) => void
|
||||
/** If true, don't auto-open the browser — just call onAuthorizationUrl */
|
||||
skipBrowserOpen?: boolean
|
||||
abortSignal?: AbortSignal
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize an IdP issuer URL for use as a cache key: strip trailing slashes,
|
||||
* lowercase host. Issuers from config and from OIDC discovery may differ
|
||||
* cosmetically but should hit the same cache slot. Exported so the setup
|
||||
* command can compare issuers using the same normalization as keychain ops.
|
||||
*/
|
||||
export function issuerKey(issuer: string): string {
|
||||
try {
|
||||
const u = new URL(issuer)
|
||||
u.pathname = u.pathname.replace(/\/+$/, '')
|
||||
u.host = u.host.toLowerCase()
|
||||
return u.toString()
|
||||
} catch {
|
||||
return issuer.replace(/\/+$/, '')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a cached id_token for the given IdP issuer from secure storage.
|
||||
* Returns undefined if missing or within ID_TOKEN_EXPIRY_BUFFER_S of expiring.
|
||||
*/
|
||||
export function getCachedIdpIdToken(idpIssuer: string): string | undefined {
|
||||
const storage = getSecureStorage()
|
||||
const data = storage.read()
|
||||
const entry = data?.mcpXaaIdp?.[issuerKey(idpIssuer)]
|
||||
if (!entry) return undefined
|
||||
const remainingMs = entry.expiresAt - Date.now()
|
||||
if (remainingMs <= ID_TOKEN_EXPIRY_BUFFER_S * 1000) return undefined
|
||||
return entry.idToken
|
||||
}
|
||||
|
||||
function saveIdpIdToken(
|
||||
idpIssuer: string,
|
||||
idToken: string,
|
||||
expiresAt: number,
|
||||
): void {
|
||||
const storage = getSecureStorage()
|
||||
const existing = storage.read() || {}
|
||||
storage.update({
|
||||
...existing,
|
||||
mcpXaaIdp: {
|
||||
...existing.mcpXaaIdp,
|
||||
[issuerKey(idpIssuer)]: { idToken, expiresAt },
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Save an externally-obtained id_token into the XAA cache — the exact slot
|
||||
* getCachedIdpIdToken/acquireIdpIdToken read from. Used by conformance testing
|
||||
* where the mock IdP hands us a pre-signed token but doesn't serve /authorize.
|
||||
*
|
||||
* Parses the JWT's exp claim for cache TTL (same as acquireIdpIdToken).
|
||||
* Returns the expiresAt it computed so the caller can report it.
|
||||
*/
|
||||
export function saveIdpIdTokenFromJwt(
|
||||
idpIssuer: string,
|
||||
idToken: string,
|
||||
): number {
|
||||
const expFromJwt = jwtExp(idToken)
|
||||
const expiresAt = expFromJwt ? expFromJwt * 1000 : Date.now() + 3600 * 1000
|
||||
saveIdpIdToken(idpIssuer, idToken, expiresAt)
|
||||
return expiresAt
|
||||
}
|
||||
|
||||
export function clearIdpIdToken(idpIssuer: string): void {
|
||||
const storage = getSecureStorage()
|
||||
const existing = storage.read()
|
||||
const key = issuerKey(idpIssuer)
|
||||
if (!existing?.mcpXaaIdp?.[key]) return
|
||||
delete existing.mcpXaaIdp[key]
|
||||
storage.update(existing)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save an IdP client secret to secure storage, keyed by IdP issuer.
|
||||
* Separate from MCP server AS secrets — different trust domain.
|
||||
* Returns the storage update result so callers can surface keychain
|
||||
* failures (locked keychain, `security` nonzero exit) instead of
|
||||
* silently dropping the secret and failing later with invalid_client.
|
||||
*/
|
||||
export function saveIdpClientSecret(
|
||||
idpIssuer: string,
|
||||
clientSecret: string,
|
||||
): { success: boolean; warning?: string } {
|
||||
const storage = getSecureStorage()
|
||||
const existing = storage.read() || {}
|
||||
return storage.update({
|
||||
...existing,
|
||||
mcpXaaIdpConfig: {
|
||||
...existing.mcpXaaIdpConfig,
|
||||
[issuerKey(idpIssuer)]: { clientSecret },
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Read the IdP client secret for the given issuer from secure storage.
|
||||
*/
|
||||
export function getIdpClientSecret(idpIssuer: string): string | undefined {
|
||||
const storage = getSecureStorage()
|
||||
const data = storage.read()
|
||||
return data?.mcpXaaIdpConfig?.[issuerKey(idpIssuer)]?.clientSecret
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove the IdP client secret for the given issuer from secure storage.
|
||||
* Used by `claude mcp xaa clear`.
|
||||
*/
|
||||
export function clearIdpClientSecret(idpIssuer: string): void {
|
||||
const storage = getSecureStorage()
|
||||
const existing = storage.read()
|
||||
const key = issuerKey(idpIssuer)
|
||||
if (!existing?.mcpXaaIdpConfig?.[key]) return
|
||||
delete existing.mcpXaaIdpConfig[key]
|
||||
storage.update(existing)
|
||||
}
|
||||
|
||||
// OIDC Discovery §4.1 says `{issuer}/.well-known/openid-configuration` — path
|
||||
// APPEND, not replace. `new URL('/.well-known/...', issuer)` with a leading
|
||||
// slash is a WHATWG absolute-path reference and drops the issuer's pathname,
|
||||
// breaking Azure AD (`login.microsoftonline.com/{tenant}/v2.0`), Okta custom
|
||||
// auth servers, and Keycloak realms. Trailing-slash base + relative path is
|
||||
// the fix. Exported because auth.ts needs the same discovery.
|
||||
export async function discoverOidc(
|
||||
idpIssuer: string,
|
||||
): Promise<OpenIdProviderDiscoveryMetadata> {
|
||||
const base = idpIssuer.endsWith('/') ? idpIssuer : idpIssuer + '/'
|
||||
const url = new URL('.well-known/openid-configuration', base)
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
const res = await fetch(url, {
|
||||
headers: { Accept: 'application/json' },
|
||||
signal: AbortSignal.timeout(IDP_REQUEST_TIMEOUT_MS),
|
||||
})
|
||||
if (!res.ok) {
|
||||
throw new Error(
|
||||
`XAA IdP: OIDC discovery failed: HTTP ${res.status} at ${url}`,
|
||||
)
|
||||
}
|
||||
// Captive portals and proxy auth pages return 200 with HTML. res.json()
|
||||
// throws a raw SyntaxError before safeParse can give a useful message.
|
||||
let body: unknown
|
||||
try {
|
||||
body = await res.json()
|
||||
} catch {
|
||||
throw new Error(
|
||||
`XAA IdP: OIDC discovery returned non-JSON at ${url} (captive portal or proxy?)`,
|
||||
)
|
||||
}
|
||||
const parsed = OpenIdProviderDiscoveryMetadataSchema.safeParse(body)
|
||||
if (!parsed.success) {
|
||||
throw new Error(`XAA IdP: invalid OIDC metadata: ${parsed.error.message}`)
|
||||
}
|
||||
if (new URL(parsed.data.token_endpoint).protocol !== 'https:') {
|
||||
throw new Error(
|
||||
`XAA IdP: refusing non-HTTPS token endpoint: ${parsed.data.token_endpoint}`,
|
||||
)
|
||||
}
|
||||
return parsed.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode the exp claim from a JWT without verifying its signature.
|
||||
* Returns undefined if parsing fails or exp is absent. Used only to
|
||||
* derive a cache TTL.
|
||||
*
|
||||
* Why no signature/iss/aud/nonce validation: per SEP-990, this id_token
|
||||
* is the RFC 8693 subject_token in a token-exchange at the IdP's own
|
||||
* token endpoint. The IdP validates its own token there. An attacker who
|
||||
* can mint a token that fools the IdP has no need to fool us first; an
|
||||
* attacker who can't, hands us garbage and gets a 401 from the IdP. The
|
||||
* --id-token injection seam is likewise safe: bad input → rejected later,
|
||||
* no privesc. Client-side verification would add code and no security.
|
||||
*/
|
||||
function jwtExp(jwt: string): number | undefined {
|
||||
const parts = jwt.split('.')
|
||||
if (parts.length !== 3) return undefined
|
||||
try {
|
||||
const payload = jsonParse(
|
||||
Buffer.from(parts[1]!, 'base64url').toString('utf-8'),
|
||||
) as { exp?: number }
|
||||
return typeof payload.exp === 'number' ? payload.exp : undefined
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for the OAuth authorization code on a local callback server.
|
||||
* Returns the code once /callback is hit with a matching state.
|
||||
*
|
||||
* `onListening` fires after the socket is actually bound — use it to defer
|
||||
* browser-open so EADDRINUSE surfaces before a spurious tab pops open.
|
||||
*/
|
||||
function waitForCallback(
|
||||
port: number,
|
||||
expectedState: string,
|
||||
abortSignal: AbortSignal | undefined,
|
||||
onListening: () => void,
|
||||
): Promise<string> {
|
||||
let server: Server | null = null
|
||||
let timeoutId: NodeJS.Timeout | null = null
|
||||
let abortHandler: (() => void) | null = null
|
||||
const cleanup = () => {
|
||||
server?.removeAllListeners()
|
||||
// Defensive: removeAllListeners() strips the error handler, so swallow any late error during close
|
||||
server?.on('error', () => {})
|
||||
server?.close()
|
||||
server = null
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId)
|
||||
timeoutId = null
|
||||
}
|
||||
if (abortSignal && abortHandler) {
|
||||
abortSignal.removeEventListener('abort', abortHandler)
|
||||
abortHandler = null
|
||||
}
|
||||
}
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
let resolved = false
|
||||
const resolveOnce = (v: string) => {
|
||||
if (resolved) return
|
||||
resolved = true
|
||||
cleanup()
|
||||
resolve(v)
|
||||
}
|
||||
const rejectOnce = (e: Error) => {
|
||||
if (resolved) return
|
||||
resolved = true
|
||||
cleanup()
|
||||
reject(e)
|
||||
}
|
||||
|
||||
if (abortSignal) {
|
||||
abortHandler = () => rejectOnce(new Error('XAA IdP: login cancelled'))
|
||||
if (abortSignal.aborted) {
|
||||
abortHandler()
|
||||
return
|
||||
}
|
||||
abortSignal.addEventListener('abort', abortHandler, { once: true })
|
||||
}
|
||||
|
||||
server = createServer((req, res) => {
|
||||
const parsed = parse(req.url || '', true)
|
||||
if (parsed.pathname !== '/callback') {
|
||||
res.writeHead(404)
|
||||
res.end()
|
||||
return
|
||||
}
|
||||
const code = parsed.query.code as string | undefined
|
||||
const state = parsed.query.state as string | undefined
|
||||
const err = parsed.query.error as string | undefined
|
||||
|
||||
if (err) {
|
||||
const desc = parsed.query.error_description as string | undefined
|
||||
const safeErr = xss(err)
|
||||
const safeDesc = desc ? xss(desc) : ''
|
||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<html><body><h3>IdP login failed</h3><p>${safeErr}</p><p>${safeDesc}</p></body></html>`,
|
||||
)
|
||||
rejectOnce(new Error(`XAA IdP: ${err}${desc ? ` — ${desc}` : ''}`))
|
||||
return
|
||||
}
|
||||
|
||||
if (state !== expectedState) {
|
||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||
res.end('<html><body><h3>State mismatch</h3></body></html>')
|
||||
rejectOnce(new Error('XAA IdP: state mismatch (possible CSRF)'))
|
||||
return
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||
res.end('<html><body><h3>Missing code</h3></body></html>')
|
||||
rejectOnce(new Error('XAA IdP: callback missing code'))
|
||||
return
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
'<html><body><h3>IdP login complete — you can close this window.</h3></body></html>',
|
||||
)
|
||||
resolveOnce(code)
|
||||
})
|
||||
|
||||
server.on('error', (err: NodeJS.ErrnoException) => {
|
||||
if (err.code === 'EADDRINUSE') {
|
||||
const findCmd =
|
||||
getPlatform() === 'windows'
|
||||
? `netstat -ano | findstr :${port}`
|
||||
: `lsof -ti:${port} -sTCP:LISTEN`
|
||||
rejectOnce(
|
||||
new Error(
|
||||
`XAA IdP: callback port ${port} is already in use. Run \`${findCmd}\` to find the holder.`,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
rejectOnce(new Error(`XAA IdP: callback server failed: ${err.message}`))
|
||||
}
|
||||
})
|
||||
|
||||
server.listen(port, '127.0.0.1', () => {
|
||||
try {
|
||||
onListening()
|
||||
} catch (e) {
|
||||
rejectOnce(toError(e))
|
||||
}
|
||||
})
|
||||
server.unref()
|
||||
timeoutId = setTimeout(
|
||||
rej => rej(new Error('XAA IdP: login timed out')),
|
||||
IDP_LOGIN_TIMEOUT_MS,
|
||||
rejectOnce,
|
||||
)
|
||||
timeoutId.unref()
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Acquire an id_token from the IdP: return cached if valid, otherwise run
|
||||
* the full OIDC authorization_code + PKCE flow (one browser pop).
|
||||
*/
|
||||
export async function acquireIdpIdToken(
|
||||
opts: IdpLoginOptions,
|
||||
): Promise<string> {
|
||||
const { idpIssuer, idpClientId } = opts
|
||||
|
||||
const cached = getCachedIdpIdToken(idpIssuer)
|
||||
if (cached) {
|
||||
logMCPDebug('xaa', `Using cached id_token for ${idpIssuer}`)
|
||||
return cached
|
||||
}
|
||||
|
||||
logMCPDebug('xaa', `No cached id_token for ${idpIssuer}; starting OIDC login`)
|
||||
|
||||
const metadata = await discoverOidc(idpIssuer)
|
||||
const port = opts.callbackPort ?? (await findAvailablePort())
|
||||
const redirectUri = buildRedirectUri(port)
|
||||
const state = randomBytes(32).toString('base64url')
|
||||
const clientInformation: OAuthClientInformation = {
|
||||
client_id: idpClientId,
|
||||
...(opts.idpClientSecret ? { client_secret: opts.idpClientSecret } : {}),
|
||||
}
|
||||
|
||||
const { authorizationUrl, codeVerifier } = await startAuthorization(
|
||||
idpIssuer,
|
||||
{
|
||||
metadata,
|
||||
clientInformation,
|
||||
redirectUrl: redirectUri,
|
||||
scope: 'openid',
|
||||
state,
|
||||
},
|
||||
)
|
||||
|
||||
// Open the browser only after the socket is actually bound — listen() is
|
||||
// async, and on the fixed-callbackPort path EADDRINUSE otherwise surfaces
|
||||
// after a spurious tab has already popped. Mirrors the auth.ts pattern of
|
||||
// wrapping sdkAuth inside server.listen's callback.
|
||||
const authorizationCode = await waitForCallback(
|
||||
port,
|
||||
state,
|
||||
opts.abortSignal,
|
||||
() => {
|
||||
if (opts.onAuthorizationUrl) {
|
||||
opts.onAuthorizationUrl(authorizationUrl.toString())
|
||||
}
|
||||
if (!opts.skipBrowserOpen) {
|
||||
logMCPDebug('xaa', `Opening browser to IdP authorization endpoint`)
|
||||
void openBrowser(authorizationUrl.toString())
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
const tokens = await exchangeAuthorization(idpIssuer, {
|
||||
metadata,
|
||||
clientInformation,
|
||||
authorizationCode,
|
||||
codeVerifier,
|
||||
redirectUri,
|
||||
fetchFn: (url, init) =>
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
fetch(url, {
|
||||
...init,
|
||||
signal: AbortSignal.timeout(IDP_REQUEST_TIMEOUT_MS),
|
||||
}),
|
||||
})
|
||||
if (!tokens.id_token) {
|
||||
throw new Error(
|
||||
'XAA IdP: token response missing id_token (check scope=openid)',
|
||||
)
|
||||
}
|
||||
|
||||
// Prefer the id_token's own exp claim; fall back to expires_in.
|
||||
// expires_in is for the access_token and may differ from the id_token
|
||||
// lifetime. If neither is present, default to 1h.
|
||||
const expFromJwt = jwtExp(tokens.id_token)
|
||||
const expiresAt = expFromJwt
|
||||
? expFromJwt * 1000
|
||||
: Date.now() + (tokens.expires_in ?? 3600) * 1000
|
||||
|
||||
saveIdpIdToken(idpIssuer, tokens.id_token, expiresAt)
|
||||
logMCPDebug(
|
||||
'xaa',
|
||||
`Cached id_token for ${idpIssuer} (expires ${new Date(expiresAt).toISOString()})`,
|
||||
)
|
||||
|
||||
return tokens.id_token
|
||||
}
|
||||
41
src/services/mcpServerApproval.tsx
Normal file
41
src/services/mcpServerApproval.tsx
Normal file
@@ -0,0 +1,41 @@
|
||||
import React from 'react';
|
||||
import { MCPServerApprovalDialog } from '../components/MCPServerApprovalDialog.js';
|
||||
import { MCPServerMultiselectDialog } from '../components/MCPServerMultiselectDialog.js';
|
||||
import type { Root } from '../ink.js';
|
||||
import { KeybindingSetup } from '../keybindings/KeybindingProviderSetup.js';
|
||||
import { AppStateProvider } from '../state/AppState.js';
|
||||
import { getMcpConfigsByScope } from './mcp/config.js';
|
||||
import { getProjectMcpServerStatus } from './mcp/utils.js';
|
||||
|
||||
/**
|
||||
* Show MCP server approval dialogs for pending project servers.
|
||||
* Uses the provided Ink root to render (reusing the existing instance
|
||||
* from main.tsx instead of creating a separate one).
|
||||
*/
|
||||
export async function handleMcpjsonServerApprovals(root: Root): Promise<void> {
|
||||
const {
|
||||
servers: projectServers
|
||||
} = getMcpConfigsByScope('project');
|
||||
const pendingServers = Object.keys(projectServers).filter(serverName => getProjectMcpServerStatus(serverName) === 'pending');
|
||||
if (pendingServers.length === 0) {
|
||||
return;
|
||||
}
|
||||
await new Promise<void>(resolve => {
|
||||
const done = (): void => void resolve();
|
||||
if (pendingServers.length === 1 && pendingServers[0] !== undefined) {
|
||||
const serverName = pendingServers[0];
|
||||
root.render(<AppStateProvider>
|
||||
<KeybindingSetup>
|
||||
<MCPServerApprovalDialog serverName={serverName} onDone={done} />
|
||||
</KeybindingSetup>
|
||||
</AppStateProvider>);
|
||||
} else {
|
||||
root.render(<AppStateProvider>
|
||||
<KeybindingSetup>
|
||||
<MCPServerMultiselectDialog serverNames={pendingServers} onDone={done} />
|
||||
</KeybindingSetup>
|
||||
</AppStateProvider>);
|
||||
}
|
||||
});
|
||||
}
|
||||
//# sourceMappingURL=data:application/json;charset=utf-8;base64,eyJ2ZXJzaW9uIjozLCJuYW1lcyI6WyJSZWFjdCIsIk1DUFNlcnZlckFwcHJvdmFsRGlhbG9nIiwiTUNQU2VydmVyTXVsdGlzZWxlY3REaWFsb2ciLCJSb290IiwiS2V5YmluZGluZ1NldHVwIiwiQXBwU3RhdGVQcm92aWRlciIsImdldE1jcENvbmZpZ3NCeVNjb3BlIiwiZ2V0UHJvamVjdE1jcFNlcnZlclN0YXR1cyIsImhhbmRsZU1jcGpzb25TZXJ2ZXJBcHByb3ZhbHMiLCJyb290IiwiUHJvbWlzZSIsInNlcnZlcnMiLCJwcm9qZWN0U2VydmVycyIsInBlbmRpbmdTZXJ2ZXJzIiwiT2JqZWN0Iiwia2V5cyIsImZpbHRlciIsInNlcnZlck5hbWUiLCJsZW5ndGgiLCJyZXNvbHZlIiwiZG9uZSIsInVuZGVmaW5lZCIsInJlbmRlciJdLCJzb3VyY2VzIjpbIm1jcFNlcnZlckFwcHJvdmFsLnRzeCJdLCJzb3VyY2VzQ29udGVudCI6WyJpbXBvcnQgUmVhY3QgZnJvbSAncmVhY3QnXG5pbXBvcnQgeyBNQ1BTZXJ2ZXJBcHByb3ZhbERpYWxvZyB9IGZyb20gJy4uL2NvbXBvbmVudHMvTUNQU2VydmVyQXBwcm92YWxEaWFsb2cuanMnXG5pbXBvcnQgeyBNQ1BTZXJ2ZXJNdWx0aXNlbGVjdERpYWxvZyB9IGZyb20gJy4uL2NvbXBvbmVudHMvTUNQU2VydmVyTXVsdGlzZWxlY3REaWFsb2cuanMnXG5pbXBvcnQgdHlwZSB7IFJvb3QgfSBmcm9tICcuLi9pbmsuanMnXG5pbXBvcnQgeyBLZXliaW5kaW5nU2V0dXAgfSBmcm9tICcuLi9rZXliaW5kaW5ncy9LZXliaW5kaW5nUHJvdmlkZXJTZXR1cC5qcydcbmltcG9ydCB7IEFwcFN0YXRlUHJvdmlkZXIgfSBmcm9tICcuLi9zdGF0ZS9BcHBTdGF0ZS5qcydcbmltcG9ydCB7IGdldE1jcENvbmZpZ3NCeVNjb3BlIH0gZnJvbSAnLi9tY3AvY29uZmlnLmpzJ1xuaW1wb3J0IHsgZ2V0UHJvamVjdE1jcFNlcnZlclN0YXR1cyB9IGZyb20gJy4vbWNwL3V0aWxzLmpzJ1xuXG4vKipcbiAqIFNob3cgTUNQIHNlcnZlciBhcHByb3ZhbCBkaWFsb2dzIGZvciBwZW5kaW5nIHByb2plY3Qgc2VydmVycy5cbiAqIFVzZXMgdGhlIHByb3ZpZGVkIEluayByb290IHRvIHJlbmRlciAocmV1c2luZyB0aGUgZXhpc3RpbmcgaW5zdGFuY2VcbiAqIGZyb20gbWFpbi50c3ggaW5zdGVhZCBvZiBjcmVhdGluZyBhIHNlcGFyYXRlIG9uZSkuXG4gKi9cbmV4cG9ydCBhc3luYyBmdW5jdGlvbiBoYW5kbGVNY3Bqc29uU2VydmVyQXBwcm92YWxzKHJvb3Q6IFJvb3QpOiBQcm9taXNlPHZvaWQ+IHtcbiAgY29uc3QgeyBzZXJ2ZXJzOiBwcm9qZWN0U2VydmVycyB9ID0gZ2V0TWNwQ29uZmlnc0J5U2NvcGUoJ3Byb2plY3QnKVxuICBjb25zdCBwZW5kaW5nU2VydmVycyA9IE9iamVjdC5rZXlzKHByb2plY3RTZXJ2ZXJzKS5maWx0ZXIoXG4gICAgc2VydmVyTmFtZSA9PiBnZXRQcm9qZWN0TWNwU2VydmVyU3RhdHVzKHNlcnZlck5hbWUpID09PSAncGVuZGluZycsXG4gIClcblxuICBpZiAocGVuZGluZ1NlcnZlcnMubGVuZ3RoID09PSAwKSB7XG4gICAgcmV0dXJuXG4gIH1cblxuICBhd2FpdCBuZXcgUHJvbWlzZTx2b2lkPihyZXNvbHZlID0+IHtcbiAgICBjb25zdCBkb25lID0gKCk6IHZvaWQgPT4gdm9pZCByZXNvbHZlKClcbiAgICBpZiAocGVuZGluZ1NlcnZlcnMubGVuZ3RoID09PSAxICYmIHBlbmRpbmdTZXJ2ZXJzWzBdICE9PSB1bmRlZmluZWQpIHtcbiAgICAgIGNvbnN0IHNlcnZlck5hbWUgPSBwZW5kaW5nU2VydmVyc1swXVxuICAgICAgcm9vdC5yZW5kZXIoXG4gICAgICAgIDxBcHBTdGF0ZVByb3ZpZGVyPlxuICAgICAgICAgIDxLZXliaW5kaW5nU2V0dXA+XG4gICAgICAgICAgICA8TUNQU2VydmVyQXBwcm92YWxEaWFsb2cgc2VydmVyTmFtZT17c2VydmVyTmFtZX0gb25Eb25lPXtkb25lfSAvPlxuICAgICAgICAgIDwvS2V5YmluZGluZ1NldHVwPlxuICAgICAgICA8L0FwcFN0YXRlUHJvdmlkZXI+LFxuICAgICAgKVxuICAgIH0gZWxzZSB7XG4gICAgICByb290LnJlbmRlcihcbiAgICAgICAgPEFwcFN0YXRlUHJvdmlkZXI+XG4gICAgICAgICAgPEtleWJpbmRpbmdTZXR1cD5cbiAgICAgICAgICAgIDxNQ1BTZXJ2ZXJNdWx0aXNlbGVjdERpYWxvZ1xuICAgICAgICAgICAgICBzZXJ2ZXJOYW1lcz17cGVuZGluZ1NlcnZlcnN9XG4gICAgICAgICAgICAgIG9uRG9uZT17ZG9uZX1cbiAgICAgICAgICAgIC8+XG4gICAgICAgICAgPC9LZXliaW5kaW5nU2V0dXA+XG4gICAgICAgIDwvQXBwU3RhdGVQcm92aWRlcj4sXG4gICAgICApXG4gICAgfVxuICB9KVxufVxuIl0sIm1hcHBpbmdzIjoiQUFBQSxPQUFPQSxLQUFLLE1BQU0sT0FBTztBQUN6QixTQUFTQyx1QkFBdUIsUUFBUSwwQ0FBMEM7QUFDbEYsU0FBU0MsMEJBQTBCLFFBQVEsNkNBQTZDO0FBQ3hGLGNBQWNDLElBQUksUUFBUSxXQUFXO0FBQ3JDLFNBQVNDLGVBQWUsUUFBUSwyQ0FBMkM7QUFDM0UsU0FBU0MsZ0JBQWdCLFFBQVEsc0JBQXNCO0FBQ3ZELFNBQVNDLG9CQUFvQixRQUFRLGlCQUFpQjtBQUN0RCxTQUFTQyx5QkFBeUIsUUFBUSxnQkFBZ0I7O0FBRTFEO0FBQ0E7QUFDQTtBQUNBO0FBQ0E7QUFDQSxPQUFPLGVBQWVDLDRCQUE0QkEsQ0FBQ0MsSUFBSSxFQUFFTixJQUFJLENBQUMsRUFBRU8sT0FBTyxDQUFDLElBQUksQ0FBQyxDQUFDO0VBQzVFLE1BQU07SUFBRUMsT0FBTyxFQUFFQztFQUFlLENBQUMsR0FBR04sb0JBQW9CLENBQUMsU0FBUyxDQUFDO0VBQ25FLE1BQU1PLGNBQWMsR0FBR0MsTUFBTSxDQUFDQyxJQUFJLENBQUNILGNBQWMsQ0FBQyxDQUFDSSxNQUFNLENBQ3ZEQyxVQUFVLElBQUlWLHlCQUF5QixDQUFDVSxVQUFVLENBQUMsS0FBSyxTQUMxRCxDQUFDO0VBRUQsSUFBSUosY0FBYyxDQUFDSyxNQUFNLEtBQUssQ0FBQyxFQUFFO0lBQy9CO0VBQ0Y7RUFFQSxNQUFNLElBQUlSLE9BQU8sQ0FBQyxJQUFJLENBQUMsQ0FBQ1MsT0FBTyxJQUFJO0lBQ2pDLE1BQU1DLElBQUksR0FBR0EsQ0FBQSxDQUFFLEVBQUUsSUFBSSxJQUFJLEtBQUtELE9BQU8sQ0FBQyxDQUFDO0lBQ3ZDLElBQUlOLGNBQWMsQ0FBQ0ssTUFBTSxLQUFLLENBQUMsSUFBSUwsY0FBYyxDQUFDLENBQUMsQ0FBQyxLQUFLUSxTQUFTLEVBQUU7TUFDbEUsTUFBTUosVUFBVSxHQUFHSixjQUFjLENBQUMsQ0FBQyxDQUFDO01BQ3BDSixJQUFJLENBQUNhLE1BQU0sQ0FDVCxDQUFDLGdCQUFnQjtBQUN6QixVQUFVLENBQUMsZUFBZTtBQUMxQixZQUFZLENBQUMsdUJBQXVCLENBQUMsVUFBVSxDQUFDLENBQUNMLFVBQVUsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDRyxJQUFJLENBQUM7QUFDMUUsVUFBVSxFQUFFLGVBQWU7QUFDM0IsUUFBUSxFQUFFLGdCQUFnQixDQUNwQixDQUFDO0lBQ0gsQ0FBQyxNQUFNO01BQ0xYLElBQUksQ0FBQ2EsTUFBTSxDQUNULENBQUMsZ0JBQWdCO0FBQ3pCLFVBQVUsQ0FBQyxlQUFlO0FBQzFCLFlBQVksQ0FBQywwQkFBMEIsQ0FDekIsV0FBVyxDQUFDLENBQUNULGNBQWMsQ0FBQyxDQUM1QixNQUFNLENBQUMsQ0FBQ08sSUFBSSxDQUFDO0FBRTNCLFVBQVUsRUFBRSxlQUFlO0FBQzNCLFFBQVEsRUFBRSxnQkFBZ0IsQ0FDcEIsQ0FBQztJQUNIO0VBQ0YsQ0FBQyxDQUFDO0FBQ0oiLCJpZ25vcmVMaXN0IjpbXX0=
|
||||
882
src/services/mockRateLimits.ts
Normal file
882
src/services/mockRateLimits.ts
Normal file
@@ -0,0 +1,882 @@
|
||||
// Mock rate limits for testing [ANT-ONLY]
|
||||
// This allows testing various rate limit scenarios without hitting actual limits
|
||||
//
|
||||
// ⚠️ WARNING: This is for internal testing/demo purposes only!
|
||||
// The mock headers may not exactly match the API specification or real-world behavior.
|
||||
// Always validate against actual API responses before relying on this for production features.
|
||||
|
||||
import type { SubscriptionType } from '../services/oauth/types.js'
|
||||
import { setMockBillingAccessOverride } from '../utils/billing.js'
|
||||
import type { OverageDisabledReason } from './claudeAiLimits.js'
|
||||
|
||||
type MockHeaders = {
|
||||
'anthropic-ratelimit-unified-status'?:
|
||||
| 'allowed'
|
||||
| 'allowed_warning'
|
||||
| 'rejected'
|
||||
'anthropic-ratelimit-unified-reset'?: string
|
||||
'anthropic-ratelimit-unified-representative-claim'?:
|
||||
| 'five_hour'
|
||||
| 'seven_day'
|
||||
| 'seven_day_opus'
|
||||
| 'seven_day_sonnet'
|
||||
'anthropic-ratelimit-unified-overage-status'?:
|
||||
| 'allowed'
|
||||
| 'allowed_warning'
|
||||
| 'rejected'
|
||||
'anthropic-ratelimit-unified-overage-reset'?: string
|
||||
'anthropic-ratelimit-unified-overage-disabled-reason'?: OverageDisabledReason
|
||||
'anthropic-ratelimit-unified-fallback'?: 'available'
|
||||
'anthropic-ratelimit-unified-fallback-percentage'?: string
|
||||
'retry-after'?: string
|
||||
// Early warning utilization headers
|
||||
'anthropic-ratelimit-unified-5h-utilization'?: string
|
||||
'anthropic-ratelimit-unified-5h-reset'?: string
|
||||
'anthropic-ratelimit-unified-5h-surpassed-threshold'?: string
|
||||
'anthropic-ratelimit-unified-7d-utilization'?: string
|
||||
'anthropic-ratelimit-unified-7d-reset'?: string
|
||||
'anthropic-ratelimit-unified-7d-surpassed-threshold'?: string
|
||||
'anthropic-ratelimit-unified-overage-utilization'?: string
|
||||
'anthropic-ratelimit-unified-overage-surpassed-threshold'?: string
|
||||
}
|
||||
|
||||
export type MockHeaderKey =
|
||||
| 'status'
|
||||
| 'reset'
|
||||
| 'claim'
|
||||
| 'overage-status'
|
||||
| 'overage-reset'
|
||||
| 'overage-disabled-reason'
|
||||
| 'fallback'
|
||||
| 'fallback-percentage'
|
||||
| 'retry-after'
|
||||
| '5h-utilization'
|
||||
| '5h-reset'
|
||||
| '5h-surpassed-threshold'
|
||||
| '7d-utilization'
|
||||
| '7d-reset'
|
||||
| '7d-surpassed-threshold'
|
||||
|
||||
export type MockScenario =
|
||||
| 'normal'
|
||||
| 'session-limit-reached'
|
||||
| 'approaching-weekly-limit'
|
||||
| 'weekly-limit-reached'
|
||||
| 'overage-active'
|
||||
| 'overage-warning'
|
||||
| 'overage-exhausted'
|
||||
| 'out-of-credits'
|
||||
| 'org-zero-credit-limit'
|
||||
| 'org-spend-cap-hit'
|
||||
| 'member-zero-credit-limit'
|
||||
| 'seat-tier-zero-credit-limit'
|
||||
| 'opus-limit'
|
||||
| 'opus-warning'
|
||||
| 'sonnet-limit'
|
||||
| 'sonnet-warning'
|
||||
| 'fast-mode-limit'
|
||||
| 'fast-mode-short-limit'
|
||||
| 'extra-usage-required'
|
||||
| 'clear'
|
||||
|
||||
let mockHeaders: MockHeaders = {}
|
||||
let mockEnabled = false
|
||||
let mockHeaderless429Message: string | null = null
|
||||
let mockSubscriptionType: SubscriptionType | null = null
|
||||
let mockFastModeRateLimitDurationMs: number | null = null
|
||||
let mockFastModeRateLimitExpiresAt: number | null = null
|
||||
// Default subscription type for mock testing
|
||||
const DEFAULT_MOCK_SUBSCRIPTION: SubscriptionType = 'max'
|
||||
|
||||
// Track individual exceeded limits with their reset times
|
||||
type ExceededLimit = {
|
||||
type: 'five_hour' | 'seven_day' | 'seven_day_opus' | 'seven_day_sonnet'
|
||||
resetsAt: number // Unix timestamp
|
||||
}
|
||||
|
||||
let exceededLimits: ExceededLimit[] = []
|
||||
|
||||
// New approach: Toggle individual headers
|
||||
export function setMockHeader(
|
||||
key: MockHeaderKey,
|
||||
value: string | undefined,
|
||||
): void {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return
|
||||
}
|
||||
|
||||
mockEnabled = true
|
||||
|
||||
// Special case for retry-after which doesn't have the prefix
|
||||
const fullKey = (
|
||||
key === 'retry-after' ? 'retry-after' : `anthropic-ratelimit-unified-${key}`
|
||||
) as keyof MockHeaders
|
||||
|
||||
if (value === undefined || value === 'clear') {
|
||||
delete mockHeaders[fullKey]
|
||||
if (key === 'claim') {
|
||||
exceededLimits = []
|
||||
}
|
||||
// Update retry-after if status changed
|
||||
if (key === 'status' || key === 'overage-status') {
|
||||
updateRetryAfter()
|
||||
}
|
||||
return
|
||||
} else {
|
||||
// Handle special cases for reset times
|
||||
if (key === 'reset' || key === 'overage-reset') {
|
||||
// If user provides a number, treat it as hours from now
|
||||
const hours = Number(value)
|
||||
if (!isNaN(hours)) {
|
||||
value = String(Math.floor(Date.now() / 1000) + hours * 3600)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle claims - add to exceeded limits
|
||||
if (key === 'claim') {
|
||||
const validClaims = [
|
||||
'five_hour',
|
||||
'seven_day',
|
||||
'seven_day_opus',
|
||||
'seven_day_sonnet',
|
||||
]
|
||||
if (validClaims.includes(value)) {
|
||||
// Determine reset time based on claim type
|
||||
let resetsAt: number
|
||||
if (value === 'five_hour') {
|
||||
resetsAt = Math.floor(Date.now() / 1000) + 5 * 3600
|
||||
} else if (
|
||||
value === 'seven_day' ||
|
||||
value === 'seven_day_opus' ||
|
||||
value === 'seven_day_sonnet'
|
||||
) {
|
||||
resetsAt = Math.floor(Date.now() / 1000) + 7 * 24 * 3600
|
||||
} else {
|
||||
resetsAt = Math.floor(Date.now() / 1000) + 3600
|
||||
}
|
||||
|
||||
// Add to exceeded limits (remove if already exists)
|
||||
exceededLimits = exceededLimits.filter(l => l.type !== value)
|
||||
exceededLimits.push({ type: value as ExceededLimit['type'], resetsAt })
|
||||
|
||||
// Set the representative claim (furthest reset time)
|
||||
updateRepresentativeClaim()
|
||||
return
|
||||
}
|
||||
}
|
||||
// Widen to a string-valued record so dynamic key assignment is allowed.
|
||||
// MockHeaders values are string-literal unions; assigning a raw user-input
|
||||
// string requires widening, but this is mock/test code so it's acceptable.
|
||||
const headers: Partial<Record<keyof MockHeaders, string>> = mockHeaders
|
||||
headers[fullKey] = value
|
||||
|
||||
// Update retry-after if status changed
|
||||
if (key === 'status' || key === 'overage-status') {
|
||||
updateRetryAfter()
|
||||
}
|
||||
}
|
||||
|
||||
// If all headers are cleared, disable mocking
|
||||
if (Object.keys(mockHeaders).length === 0) {
|
||||
mockEnabled = false
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to update retry-after based on current state
|
||||
function updateRetryAfter(): void {
|
||||
const status = mockHeaders['anthropic-ratelimit-unified-status']
|
||||
const overageStatus =
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status']
|
||||
const reset = mockHeaders['anthropic-ratelimit-unified-reset']
|
||||
|
||||
if (
|
||||
status === 'rejected' &&
|
||||
(!overageStatus || overageStatus === 'rejected') &&
|
||||
reset
|
||||
) {
|
||||
// Calculate seconds until reset
|
||||
const resetTimestamp = Number(reset)
|
||||
const secondsUntilReset = Math.max(
|
||||
0,
|
||||
resetTimestamp - Math.floor(Date.now() / 1000),
|
||||
)
|
||||
mockHeaders['retry-after'] = String(secondsUntilReset)
|
||||
} else {
|
||||
delete mockHeaders['retry-after']
|
||||
}
|
||||
}
|
||||
|
||||
// Update the representative claim based on exceeded limits
|
||||
function updateRepresentativeClaim(): void {
|
||||
if (exceededLimits.length === 0) {
|
||||
delete mockHeaders['anthropic-ratelimit-unified-representative-claim']
|
||||
delete mockHeaders['anthropic-ratelimit-unified-reset']
|
||||
delete mockHeaders['retry-after']
|
||||
return
|
||||
}
|
||||
|
||||
// Find the limit with the furthest reset time
|
||||
const furthest = exceededLimits.reduce((prev, curr) =>
|
||||
curr.resetsAt > prev.resetsAt ? curr : prev,
|
||||
)
|
||||
|
||||
// Set the representative claim (appears for both warning and rejected)
|
||||
mockHeaders['anthropic-ratelimit-unified-representative-claim'] =
|
||||
furthest.type
|
||||
mockHeaders['anthropic-ratelimit-unified-reset'] = String(furthest.resetsAt)
|
||||
|
||||
// Add retry-after if rejected and no overage available
|
||||
if (mockHeaders['anthropic-ratelimit-unified-status'] === 'rejected') {
|
||||
const overageStatus =
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status']
|
||||
if (!overageStatus || overageStatus === 'rejected') {
|
||||
// Calculate seconds until reset
|
||||
const secondsUntilReset = Math.max(
|
||||
0,
|
||||
furthest.resetsAt - Math.floor(Date.now() / 1000),
|
||||
)
|
||||
mockHeaders['retry-after'] = String(secondsUntilReset)
|
||||
} else {
|
||||
// Overage is available, no retry-after
|
||||
delete mockHeaders['retry-after']
|
||||
}
|
||||
} else {
|
||||
delete mockHeaders['retry-after']
|
||||
}
|
||||
}
|
||||
|
||||
// Add function to add exceeded limit with custom reset time
|
||||
export function addExceededLimit(
|
||||
type: 'five_hour' | 'seven_day' | 'seven_day_opus' | 'seven_day_sonnet',
|
||||
hoursFromNow: number,
|
||||
): void {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return
|
||||
}
|
||||
|
||||
mockEnabled = true
|
||||
const resetsAt = Math.floor(Date.now() / 1000) + hoursFromNow * 3600
|
||||
|
||||
// Remove existing limit of same type
|
||||
exceededLimits = exceededLimits.filter(l => l.type !== type)
|
||||
exceededLimits.push({ type, resetsAt })
|
||||
|
||||
// Update status to rejected if we have exceeded limits
|
||||
if (exceededLimits.length > 0) {
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
}
|
||||
|
||||
updateRepresentativeClaim()
|
||||
}
|
||||
|
||||
// Set mock early warning utilization for time-relative thresholds
|
||||
// claimAbbrev: '5h' or '7d'
|
||||
// utilization: 0-1 (e.g., 0.92 for 92% used)
|
||||
// hoursFromNow: hours until reset (default: 4 for 5h, 120 for 7d)
|
||||
export function setMockEarlyWarning(
|
||||
claimAbbrev: '5h' | '7d' | 'overage',
|
||||
utilization: number,
|
||||
hoursFromNow?: number,
|
||||
): void {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return
|
||||
}
|
||||
|
||||
mockEnabled = true
|
||||
|
||||
// Clear ALL early warning headers first (5h is checked before 7d, so we need
|
||||
// to clear 5h headers when testing 7d to avoid 5h taking priority)
|
||||
clearMockEarlyWarning()
|
||||
|
||||
// Default hours based on claim type (early in window to trigger warning)
|
||||
const defaultHours = claimAbbrev === '5h' ? 4 : 5 * 24
|
||||
const hours = hoursFromNow ?? defaultHours
|
||||
const resetsAt = Math.floor(Date.now() / 1000) + hours * 3600
|
||||
|
||||
mockHeaders[`anthropic-ratelimit-unified-${claimAbbrev}-utilization`] =
|
||||
String(utilization)
|
||||
mockHeaders[`anthropic-ratelimit-unified-${claimAbbrev}-reset`] =
|
||||
String(resetsAt)
|
||||
// Set the surpassed-threshold header to trigger early warning
|
||||
mockHeaders[
|
||||
`anthropic-ratelimit-unified-${claimAbbrev}-surpassed-threshold`
|
||||
] = String(utilization)
|
||||
|
||||
// Set status to allowed so early warning logic can upgrade it
|
||||
if (!mockHeaders['anthropic-ratelimit-unified-status']) {
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'allowed'
|
||||
}
|
||||
}
|
||||
|
||||
// Clear mock early warning headers
|
||||
export function clearMockEarlyWarning(): void {
|
||||
delete mockHeaders['anthropic-ratelimit-unified-5h-utilization']
|
||||
delete mockHeaders['anthropic-ratelimit-unified-5h-reset']
|
||||
delete mockHeaders['anthropic-ratelimit-unified-5h-surpassed-threshold']
|
||||
delete mockHeaders['anthropic-ratelimit-unified-7d-utilization']
|
||||
delete mockHeaders['anthropic-ratelimit-unified-7d-reset']
|
||||
delete mockHeaders['anthropic-ratelimit-unified-7d-surpassed-threshold']
|
||||
}
|
||||
|
||||
export function setMockRateLimitScenario(scenario: MockScenario): void {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return
|
||||
}
|
||||
|
||||
if (scenario === 'clear') {
|
||||
mockHeaders = {}
|
||||
mockHeaderless429Message = null
|
||||
mockEnabled = false
|
||||
return
|
||||
}
|
||||
|
||||
mockEnabled = true
|
||||
|
||||
// Set reset times for demos
|
||||
const fiveHoursFromNow = Math.floor(Date.now() / 1000) + 5 * 3600
|
||||
const sevenDaysFromNow = Math.floor(Date.now() / 1000) + 7 * 24 * 3600
|
||||
|
||||
// Clear existing headers
|
||||
mockHeaders = {}
|
||||
mockHeaderless429Message = null
|
||||
|
||||
// Only clear exceeded limits for scenarios that explicitly set them
|
||||
// Overage scenarios should preserve existing exceeded limits
|
||||
const preserveExceededLimits = [
|
||||
'overage-active',
|
||||
'overage-warning',
|
||||
'overage-exhausted',
|
||||
].includes(scenario)
|
||||
if (!preserveExceededLimits) {
|
||||
exceededLimits = []
|
||||
}
|
||||
|
||||
switch (scenario) {
|
||||
case 'normal':
|
||||
mockHeaders = {
|
||||
'anthropic-ratelimit-unified-status': 'allowed',
|
||||
'anthropic-ratelimit-unified-reset': String(fiveHoursFromNow),
|
||||
}
|
||||
break
|
||||
|
||||
case 'session-limit-reached':
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
break
|
||||
|
||||
case 'approaching-weekly-limit':
|
||||
mockHeaders = {
|
||||
'anthropic-ratelimit-unified-status': 'allowed_warning',
|
||||
'anthropic-ratelimit-unified-reset': String(sevenDaysFromNow),
|
||||
'anthropic-ratelimit-unified-representative-claim': 'seven_day',
|
||||
}
|
||||
break
|
||||
|
||||
case 'weekly-limit-reached':
|
||||
exceededLimits = [{ type: 'seven_day', resetsAt: sevenDaysFromNow }]
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
break
|
||||
|
||||
case 'overage-active': {
|
||||
// If no limits have been exceeded yet, default to 5-hour
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] = 'allowed'
|
||||
// Set overage reset time (monthly)
|
||||
const endOfMonthActive = new Date()
|
||||
endOfMonthActive.setMonth(endOfMonthActive.getMonth() + 1, 1)
|
||||
endOfMonthActive.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonthActive.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'overage-warning': {
|
||||
// If no limits have been exceeded yet, default to 5-hour
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] =
|
||||
'allowed_warning'
|
||||
// Overage typically resets monthly, but for demo let's say end of month
|
||||
const endOfMonth = new Date()
|
||||
endOfMonth.setMonth(endOfMonth.getMonth() + 1, 1)
|
||||
endOfMonth.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonth.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'overage-exhausted': {
|
||||
// If no limits have been exceeded yet, default to 5-hour
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] = 'rejected'
|
||||
// Both subscription and overage are exhausted
|
||||
// Subscription resets based on the exceeded limit, overage resets monthly
|
||||
const endOfMonthExhausted = new Date()
|
||||
endOfMonthExhausted.setMonth(endOfMonthExhausted.getMonth() + 1, 1)
|
||||
endOfMonthExhausted.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonthExhausted.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'out-of-credits': {
|
||||
// Out of credits - subscription limit hit, overage rejected due to insufficient credits
|
||||
// (wallet is empty)
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-disabled-reason'] =
|
||||
'out_of_credits'
|
||||
const endOfMonth = new Date()
|
||||
endOfMonth.setMonth(endOfMonth.getMonth() + 1, 1)
|
||||
endOfMonth.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonth.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'org-zero-credit-limit': {
|
||||
// Org service has zero credit limit - admin set org-level spend cap to $0
|
||||
// Non-admin Team/Enterprise users should not see "Request extra usage" option
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-disabled-reason'] =
|
||||
'org_service_zero_credit_limit'
|
||||
const endOfMonthZero = new Date()
|
||||
endOfMonthZero.setMonth(endOfMonthZero.getMonth() + 1, 1)
|
||||
endOfMonthZero.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonthZero.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'org-spend-cap-hit': {
|
||||
// Org spend cap hit for the month - org overages temporarily disabled
|
||||
// Non-admin Team/Enterprise users should not see "Request extra usage" option
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-disabled-reason'] =
|
||||
'org_level_disabled_until'
|
||||
const endOfMonthHit = new Date()
|
||||
endOfMonthHit.setMonth(endOfMonthHit.getMonth() + 1, 1)
|
||||
endOfMonthHit.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonthHit.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'member-zero-credit-limit': {
|
||||
// Member has zero credit limit - admin set this user's individual limit to $0
|
||||
// Non-admin Team/Enterprise users SHOULD see "Request extra usage" (admin can allocate more)
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-disabled-reason'] =
|
||||
'member_zero_credit_limit'
|
||||
const endOfMonthMember = new Date()
|
||||
endOfMonthMember.setMonth(endOfMonthMember.getMonth() + 1, 1)
|
||||
endOfMonthMember.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonthMember.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'seat-tier-zero-credit-limit': {
|
||||
// Seat tier has zero credit limit - admin set this seat tier's limit to $0
|
||||
// Non-admin Team/Enterprise users SHOULD see "Request extra usage" (admin can allocate more)
|
||||
if (exceededLimits.length === 0) {
|
||||
exceededLimits = [{ type: 'five_hour', resetsAt: fiveHoursFromNow }]
|
||||
}
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-status'] = 'rejected'
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-disabled-reason'] =
|
||||
'seat_tier_zero_credit_limit'
|
||||
const endOfMonthSeatTier = new Date()
|
||||
endOfMonthSeatTier.setMonth(endOfMonthSeatTier.getMonth() + 1, 1)
|
||||
endOfMonthSeatTier.setHours(0, 0, 0, 0)
|
||||
mockHeaders['anthropic-ratelimit-unified-overage-reset'] = String(
|
||||
Math.floor(endOfMonthSeatTier.getTime() / 1000),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
case 'opus-limit': {
|
||||
exceededLimits = [{ type: 'seven_day_opus', resetsAt: sevenDaysFromNow }]
|
||||
updateRepresentativeClaim()
|
||||
// Always send 429 rejected status - the error handler will decide whether
|
||||
// to show an error or return NO_RESPONSE_REQUESTED based on fallback eligibility
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
break
|
||||
}
|
||||
|
||||
case 'opus-warning': {
|
||||
mockHeaders = {
|
||||
'anthropic-ratelimit-unified-status': 'allowed_warning',
|
||||
'anthropic-ratelimit-unified-reset': String(sevenDaysFromNow),
|
||||
'anthropic-ratelimit-unified-representative-claim': 'seven_day_opus',
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'sonnet-limit': {
|
||||
exceededLimits = [
|
||||
{ type: 'seven_day_sonnet', resetsAt: sevenDaysFromNow },
|
||||
]
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
break
|
||||
}
|
||||
|
||||
case 'sonnet-warning': {
|
||||
mockHeaders = {
|
||||
'anthropic-ratelimit-unified-status': 'allowed_warning',
|
||||
'anthropic-ratelimit-unified-reset': String(sevenDaysFromNow),
|
||||
'anthropic-ratelimit-unified-representative-claim': 'seven_day_sonnet',
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'fast-mode-limit': {
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
// Duration in ms (> 20s threshold to trigger cooldown)
|
||||
mockFastModeRateLimitDurationMs = 10 * 60 * 1000
|
||||
break
|
||||
}
|
||||
|
||||
case 'fast-mode-short-limit': {
|
||||
updateRepresentativeClaim()
|
||||
mockHeaders['anthropic-ratelimit-unified-status'] = 'rejected'
|
||||
// Duration in ms (< 20s threshold, won't trigger cooldown)
|
||||
mockFastModeRateLimitDurationMs = 10 * 1000
|
||||
break
|
||||
}
|
||||
|
||||
case 'extra-usage-required': {
|
||||
// Headerless 429 — exercises the entitlement-rejection path in errors.ts
|
||||
mockHeaderless429Message =
|
||||
'Extra usage is required for long context requests.'
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
export function getMockHeaderless429Message(): string | null {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return null
|
||||
}
|
||||
// Env var path for -p / SDK testing where slash commands aren't available
|
||||
if (process.env.CLAUDE_MOCK_HEADERLESS_429) {
|
||||
return process.env.CLAUDE_MOCK_HEADERLESS_429
|
||||
}
|
||||
if (!mockEnabled) {
|
||||
return null
|
||||
}
|
||||
return mockHeaderless429Message
|
||||
}
|
||||
|
||||
export function getMockHeaders(): MockHeaders | null {
|
||||
if (
|
||||
!mockEnabled ||
|
||||
process.env.USER_TYPE !== 'ant' ||
|
||||
Object.keys(mockHeaders).length === 0
|
||||
) {
|
||||
return null
|
||||
}
|
||||
return mockHeaders
|
||||
}
|
||||
|
||||
export function getMockStatus(): string {
|
||||
if (
|
||||
!mockEnabled ||
|
||||
(Object.keys(mockHeaders).length === 0 && !mockSubscriptionType)
|
||||
) {
|
||||
return 'No mock headers active (using real limits)'
|
||||
}
|
||||
|
||||
const lines: string[] = []
|
||||
lines.push('Active mock headers:')
|
||||
|
||||
// Show subscription type - either explicitly set or default
|
||||
const effectiveSubscription =
|
||||
mockSubscriptionType || DEFAULT_MOCK_SUBSCRIPTION
|
||||
if (mockSubscriptionType) {
|
||||
lines.push(` Subscription Type: ${mockSubscriptionType} (explicitly set)`)
|
||||
} else {
|
||||
lines.push(` Subscription Type: ${effectiveSubscription} (default)`)
|
||||
}
|
||||
|
||||
Object.entries(mockHeaders).forEach(([key, value]) => {
|
||||
if (value !== undefined) {
|
||||
// Format the header name nicely
|
||||
const formattedKey = key
|
||||
.replace('anthropic-ratelimit-unified-', '')
|
||||
.replace(/-/g, ' ')
|
||||
.replace(/\b\w/g, c => c.toUpperCase())
|
||||
|
||||
// Format timestamps as human-readable
|
||||
if (key.includes('reset') && value) {
|
||||
const timestamp = Number(value)
|
||||
const date = new Date(timestamp * 1000)
|
||||
lines.push(` ${formattedKey}: ${value} (${date.toLocaleString()})`)
|
||||
} else {
|
||||
lines.push(` ${formattedKey}: ${value}`)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Show exceeded limits if any
|
||||
if (exceededLimits.length > 0) {
|
||||
lines.push('\nExceeded limits (contributing to representative claim):')
|
||||
exceededLimits.forEach(limit => {
|
||||
const date = new Date(limit.resetsAt * 1000)
|
||||
lines.push(` ${limit.type}: resets at ${date.toLocaleString()}`)
|
||||
})
|
||||
}
|
||||
|
||||
return lines.join('\n')
|
||||
}
|
||||
|
||||
export function clearMockHeaders(): void {
|
||||
mockHeaders = {}
|
||||
exceededLimits = []
|
||||
mockSubscriptionType = null
|
||||
mockFastModeRateLimitDurationMs = null
|
||||
mockFastModeRateLimitExpiresAt = null
|
||||
mockHeaderless429Message = null
|
||||
setMockBillingAccessOverride(null)
|
||||
mockEnabled = false
|
||||
}
|
||||
|
||||
export function applyMockHeaders(
|
||||
headers: globalThis.Headers,
|
||||
): globalThis.Headers {
|
||||
const mock = getMockHeaders()
|
||||
if (!mock) {
|
||||
return headers
|
||||
}
|
||||
|
||||
// Create a new Headers object with original headers
|
||||
// eslint-disable-next-line eslint-plugin-n/no-unsupported-features/node-builtins
|
||||
const newHeaders = new globalThis.Headers(headers)
|
||||
|
||||
// Apply mock headers (overwriting originals)
|
||||
Object.entries(mock).forEach(([key, value]) => {
|
||||
if (value !== undefined) {
|
||||
newHeaders.set(key, value)
|
||||
}
|
||||
})
|
||||
|
||||
return newHeaders
|
||||
}
|
||||
|
||||
// Check if we should process rate limits even without subscription
|
||||
// This is for Ant employees testing with mocks
|
||||
export function shouldProcessMockLimits(): boolean {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return false
|
||||
}
|
||||
return mockEnabled || Boolean(process.env.CLAUDE_MOCK_HEADERLESS_429)
|
||||
}
|
||||
|
||||
export function getCurrentMockScenario(): MockScenario | null {
|
||||
if (!mockEnabled) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Reverse lookup the scenario from current headers
|
||||
if (!mockHeaders) return null
|
||||
|
||||
const status = mockHeaders['anthropic-ratelimit-unified-status']
|
||||
const overage = mockHeaders['anthropic-ratelimit-unified-overage-status']
|
||||
const claim = mockHeaders['anthropic-ratelimit-unified-representative-claim']
|
||||
|
||||
if (claim === 'seven_day_opus') {
|
||||
return status === 'rejected' ? 'opus-limit' : 'opus-warning'
|
||||
}
|
||||
|
||||
if (claim === 'seven_day_sonnet') {
|
||||
return status === 'rejected' ? 'sonnet-limit' : 'sonnet-warning'
|
||||
}
|
||||
|
||||
if (overage === 'rejected') return 'overage-exhausted'
|
||||
if (overage === 'allowed_warning') return 'overage-warning'
|
||||
if (overage === 'allowed') return 'overage-active'
|
||||
|
||||
if (status === 'rejected') {
|
||||
if (claim === 'five_hour') return 'session-limit-reached'
|
||||
if (claim === 'seven_day') return 'weekly-limit-reached'
|
||||
}
|
||||
|
||||
if (status === 'allowed_warning') {
|
||||
if (claim === 'seven_day') return 'approaching-weekly-limit'
|
||||
}
|
||||
|
||||
if (status === 'allowed') return 'normal'
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export function getScenarioDescription(scenario: MockScenario): string {
|
||||
switch (scenario) {
|
||||
case 'normal':
|
||||
return 'Normal usage, no limits'
|
||||
case 'session-limit-reached':
|
||||
return 'Session rate limit exceeded'
|
||||
case 'approaching-weekly-limit':
|
||||
return 'Approaching weekly aggregate limit'
|
||||
case 'weekly-limit-reached':
|
||||
return 'Weekly aggregate limit exceeded'
|
||||
case 'overage-active':
|
||||
return 'Using extra usage (overage active)'
|
||||
case 'overage-warning':
|
||||
return 'Approaching extra usage limit'
|
||||
case 'overage-exhausted':
|
||||
return 'Both subscription and extra usage limits exhausted'
|
||||
case 'out-of-credits':
|
||||
return 'Out of extra usage credits (wallet empty)'
|
||||
case 'org-zero-credit-limit':
|
||||
return 'Org spend cap is zero (no extra usage budget)'
|
||||
case 'org-spend-cap-hit':
|
||||
return 'Org spend cap hit for the month'
|
||||
case 'member-zero-credit-limit':
|
||||
return 'Member limit is zero (admin can allocate more)'
|
||||
case 'seat-tier-zero-credit-limit':
|
||||
return 'Seat tier limit is zero (admin can allocate more)'
|
||||
case 'opus-limit':
|
||||
return 'Opus limit reached'
|
||||
case 'opus-warning':
|
||||
return 'Approaching Opus limit'
|
||||
case 'sonnet-limit':
|
||||
return 'Sonnet limit reached'
|
||||
case 'sonnet-warning':
|
||||
return 'Approaching Sonnet limit'
|
||||
case 'fast-mode-limit':
|
||||
return 'Fast mode rate limit'
|
||||
case 'fast-mode-short-limit':
|
||||
return 'Fast mode rate limit (short)'
|
||||
case 'extra-usage-required':
|
||||
return 'Headerless 429: Extra usage required for 1M context'
|
||||
case 'clear':
|
||||
return 'Clear mock headers (use real limits)'
|
||||
default:
|
||||
return 'Unknown scenario'
|
||||
}
|
||||
}
|
||||
|
||||
// Mock subscription type management
|
||||
export function setMockSubscriptionType(
|
||||
subscriptionType: SubscriptionType | null,
|
||||
): void {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return
|
||||
}
|
||||
mockEnabled = true
|
||||
mockSubscriptionType = subscriptionType
|
||||
}
|
||||
|
||||
export function getMockSubscriptionType(): SubscriptionType | null {
|
||||
if (!mockEnabled || process.env.USER_TYPE !== 'ant') {
|
||||
return null
|
||||
}
|
||||
// Return the explicitly set subscription type, or default to 'max'
|
||||
return mockSubscriptionType || DEFAULT_MOCK_SUBSCRIPTION
|
||||
}
|
||||
|
||||
// Export a function that checks if we should use mock subscription
|
||||
export function shouldUseMockSubscription(): boolean {
|
||||
return (
|
||||
mockEnabled &&
|
||||
mockSubscriptionType !== null &&
|
||||
process.env.USER_TYPE === 'ant'
|
||||
)
|
||||
}
|
||||
|
||||
// Mock billing access (admin vs non-admin)
|
||||
export function setMockBillingAccess(hasAccess: boolean | null): void {
|
||||
if (process.env.USER_TYPE !== 'ant') {
|
||||
return
|
||||
}
|
||||
mockEnabled = true
|
||||
setMockBillingAccessOverride(hasAccess)
|
||||
}
|
||||
|
||||
// Mock fast mode rate limit handling
|
||||
export function isMockFastModeRateLimitScenario(): boolean {
|
||||
return mockFastModeRateLimitDurationMs !== null
|
||||
}
|
||||
|
||||
export function checkMockFastModeRateLimit(
|
||||
isFastModeActive?: boolean,
|
||||
): MockHeaders | null {
|
||||
if (mockFastModeRateLimitDurationMs === null) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Only throw when fast mode is active
|
||||
if (!isFastModeActive) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if the rate limit has expired
|
||||
if (
|
||||
mockFastModeRateLimitExpiresAt !== null &&
|
||||
Date.now() >= mockFastModeRateLimitExpiresAt
|
||||
) {
|
||||
clearMockHeaders()
|
||||
return null
|
||||
}
|
||||
|
||||
// Set expiry on first error (not when scenario is configured)
|
||||
if (mockFastModeRateLimitExpiresAt === null) {
|
||||
mockFastModeRateLimitExpiresAt =
|
||||
Date.now() + mockFastModeRateLimitDurationMs
|
||||
}
|
||||
|
||||
// Compute dynamic retry-after based on remaining time
|
||||
const remainingMs = mockFastModeRateLimitExpiresAt - Date.now()
|
||||
const headersToSend = { ...mockHeaders }
|
||||
headersToSend['retry-after'] = String(
|
||||
Math.max(1, Math.ceil(remainingMs / 1000)),
|
||||
)
|
||||
|
||||
return headersToSend
|
||||
}
|
||||
156
src/services/notifier.ts
Normal file
156
src/services/notifier.ts
Normal file
@@ -0,0 +1,156 @@
|
||||
import type { TerminalNotification } from '../ink/useTerminalNotification.js'
|
||||
import { getGlobalConfig } from '../utils/config.js'
|
||||
import { env } from '../utils/env.js'
|
||||
import { execFileNoThrow } from '../utils/execFileNoThrow.js'
|
||||
import { executeNotificationHooks } from '../utils/hooks.js'
|
||||
import { logError } from '../utils/log.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from './analytics/index.js'
|
||||
|
||||
export type NotificationOptions = {
|
||||
message: string
|
||||
title?: string
|
||||
notificationType: string
|
||||
}
|
||||
|
||||
export async function sendNotification(
|
||||
notif: NotificationOptions,
|
||||
terminal: TerminalNotification,
|
||||
): Promise<void> {
|
||||
const config = getGlobalConfig()
|
||||
const channel = config.preferredNotifChannel
|
||||
|
||||
await executeNotificationHooks(notif)
|
||||
|
||||
const methodUsed = await sendToChannel(channel, notif, terminal)
|
||||
|
||||
logEvent('tengu_notification_method_used', {
|
||||
configured_channel:
|
||||
channel as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
method_used:
|
||||
methodUsed as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
term: env.terminal as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
}
|
||||
|
||||
const DEFAULT_TITLE = 'Claude Code'
|
||||
|
||||
async function sendToChannel(
|
||||
channel: string,
|
||||
opts: NotificationOptions,
|
||||
terminal: TerminalNotification,
|
||||
): Promise<string> {
|
||||
const title = opts.title || DEFAULT_TITLE
|
||||
|
||||
try {
|
||||
switch (channel) {
|
||||
case 'auto':
|
||||
return sendAuto(opts, terminal)
|
||||
case 'iterm2':
|
||||
terminal.notifyITerm2(opts)
|
||||
return 'iterm2'
|
||||
case 'iterm2_with_bell':
|
||||
terminal.notifyITerm2(opts)
|
||||
terminal.notifyBell()
|
||||
return 'iterm2_with_bell'
|
||||
case 'kitty':
|
||||
terminal.notifyKitty({ ...opts, title, id: generateKittyId() })
|
||||
return 'kitty'
|
||||
case 'ghostty':
|
||||
terminal.notifyGhostty({ ...opts, title })
|
||||
return 'ghostty'
|
||||
case 'terminal_bell':
|
||||
terminal.notifyBell()
|
||||
return 'terminal_bell'
|
||||
case 'notifications_disabled':
|
||||
return 'disabled'
|
||||
default:
|
||||
return 'none'
|
||||
}
|
||||
} catch {
|
||||
return 'error'
|
||||
}
|
||||
}
|
||||
|
||||
async function sendAuto(
|
||||
opts: NotificationOptions,
|
||||
terminal: TerminalNotification,
|
||||
): Promise<string> {
|
||||
const title = opts.title || DEFAULT_TITLE
|
||||
|
||||
switch (env.terminal) {
|
||||
case 'Apple_Terminal': {
|
||||
const bellDisabled = await isAppleTerminalBellDisabled()
|
||||
if (bellDisabled) {
|
||||
terminal.notifyBell()
|
||||
return 'terminal_bell'
|
||||
}
|
||||
return 'no_method_available'
|
||||
}
|
||||
case 'iTerm.app':
|
||||
terminal.notifyITerm2(opts)
|
||||
return 'iterm2'
|
||||
case 'kitty':
|
||||
terminal.notifyKitty({ ...opts, title, id: generateKittyId() })
|
||||
return 'kitty'
|
||||
case 'ghostty':
|
||||
terminal.notifyGhostty({ ...opts, title })
|
||||
return 'ghostty'
|
||||
default:
|
||||
return 'no_method_available'
|
||||
}
|
||||
}
|
||||
|
||||
function generateKittyId(): number {
|
||||
return Math.floor(Math.random() * 10000)
|
||||
}
|
||||
|
||||
async function isAppleTerminalBellDisabled(): Promise<boolean> {
|
||||
try {
|
||||
if (env.terminal !== 'Apple_Terminal') {
|
||||
return false
|
||||
}
|
||||
|
||||
const osascriptResult = await execFileNoThrow('osascript', [
|
||||
'-e',
|
||||
'tell application "Terminal" to name of current settings of front window',
|
||||
])
|
||||
const currentProfile = osascriptResult.stdout.trim()
|
||||
|
||||
if (!currentProfile) {
|
||||
return false
|
||||
}
|
||||
|
||||
const defaultsOutput = await execFileNoThrow('defaults', [
|
||||
'export',
|
||||
'com.apple.Terminal',
|
||||
'-',
|
||||
])
|
||||
|
||||
if (defaultsOutput.code !== 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Lazy-load plist (~280KB with xmlbuilder+@xmldom) — only hit on
|
||||
// Apple_Terminal with auto-channel, which is a small fraction of users.
|
||||
const plist = await import('plist')
|
||||
const parsed: Record<string, unknown> = plist.parse(defaultsOutput.stdout)
|
||||
const windowSettings = parsed?.['Window Settings'] as
|
||||
| Record<string, unknown>
|
||||
| undefined
|
||||
const profileSettings = windowSettings?.[currentProfile] as
|
||||
| Record<string, unknown>
|
||||
| undefined
|
||||
|
||||
if (!profileSettings) {
|
||||
return false
|
||||
}
|
||||
|
||||
return profileSettings.Bell === false
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
211
src/services/oauth/auth-code-listener.ts
Normal file
211
src/services/oauth/auth-code-listener.ts
Normal file
@@ -0,0 +1,211 @@
|
||||
import type { IncomingMessage, ServerResponse } from 'http'
|
||||
import { createServer, type Server } from 'http'
|
||||
import type { AddressInfo } from 'net'
|
||||
import { logEvent } from 'src/services/analytics/index.js'
|
||||
import { getOauthConfig } from '../../constants/oauth.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { shouldUseClaudeAIAuth } from './client.js'
|
||||
|
||||
/**
|
||||
* Temporary localhost HTTP server that listens for OAuth authorization code redirects.
|
||||
*
|
||||
* When the user authorizes in their browser, the OAuth provider redirects to:
|
||||
* http://localhost:[port]/callback?code=AUTH_CODE&state=STATE
|
||||
*
|
||||
* This server captures that redirect and extracts the auth code.
|
||||
* Note: This is NOT an OAuth server - it's just a redirect capture mechanism.
|
||||
*/
|
||||
export class AuthCodeListener {
|
||||
private localServer: Server
|
||||
private port: number = 0
|
||||
private promiseResolver: ((authorizationCode: string) => void) | null = null
|
||||
private promiseRejecter: ((error: Error) => void) | null = null
|
||||
private expectedState: string | null = null // State parameter for CSRF protection
|
||||
private pendingResponse: ServerResponse | null = null // Response object for final redirect
|
||||
private callbackPath: string // Configurable callback path
|
||||
|
||||
constructor(callbackPath: string = '/callback') {
|
||||
this.localServer = createServer()
|
||||
this.callbackPath = callbackPath
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts listening on an OS-assigned port and returns the port number.
|
||||
* This avoids race conditions by keeping the server open until it's used.
|
||||
* @param port Optional specific port to use. If not provided, uses OS-assigned port.
|
||||
*/
|
||||
async start(port?: number): Promise<number> {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.localServer.once('error', err => {
|
||||
reject(
|
||||
new Error(`Failed to start OAuth callback server: ${err.message}`),
|
||||
)
|
||||
})
|
||||
|
||||
// Listen on specified port or 0 to let the OS assign an available port
|
||||
this.localServer.listen(port ?? 0, 'localhost', () => {
|
||||
const address = this.localServer.address() as AddressInfo
|
||||
this.port = address.port
|
||||
resolve(this.port)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
getPort(): number {
|
||||
return this.port
|
||||
}
|
||||
|
||||
hasPendingResponse(): boolean {
|
||||
return this.pendingResponse !== null
|
||||
}
|
||||
|
||||
async waitForAuthorization(
|
||||
state: string,
|
||||
onReady: () => Promise<void>,
|
||||
): Promise<string> {
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
this.promiseResolver = resolve
|
||||
this.promiseRejecter = reject
|
||||
this.expectedState = state
|
||||
this.startLocalListener(onReady)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Completes the OAuth flow by redirecting the user's browser to a success page.
|
||||
* Different success pages are shown based on the granted scopes.
|
||||
* @param scopes The OAuth scopes that were granted
|
||||
* @param customHandler Optional custom handler to serve response instead of redirecting
|
||||
*/
|
||||
handleSuccessRedirect(
|
||||
scopes: string[],
|
||||
customHandler?: (res: ServerResponse, scopes: string[]) => void,
|
||||
): void {
|
||||
if (!this.pendingResponse) return
|
||||
|
||||
// If custom handler provided, use it instead of default redirect
|
||||
if (customHandler) {
|
||||
customHandler(this.pendingResponse, scopes)
|
||||
this.pendingResponse = null
|
||||
logEvent('tengu_oauth_automatic_redirect', { custom_handler: true })
|
||||
return
|
||||
}
|
||||
|
||||
// Default behavior: Choose success page based on granted permissions
|
||||
const successUrl = shouldUseClaudeAIAuth(scopes)
|
||||
? getOauthConfig().CLAUDEAI_SUCCESS_URL
|
||||
: getOauthConfig().CONSOLE_SUCCESS_URL
|
||||
|
||||
// Send browser to success page
|
||||
this.pendingResponse.writeHead(302, { Location: successUrl })
|
||||
this.pendingResponse.end()
|
||||
this.pendingResponse = null
|
||||
|
||||
logEvent('tengu_oauth_automatic_redirect', {})
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles error case by sending a redirect to the appropriate success page with an error indicator,
|
||||
* ensuring the browser flow is completed properly.
|
||||
*/
|
||||
handleErrorRedirect(): void {
|
||||
if (!this.pendingResponse) return
|
||||
|
||||
// TODO: swap to a different url once we have an error page
|
||||
const errorUrl = getOauthConfig().CLAUDEAI_SUCCESS_URL
|
||||
|
||||
// Send browser to error page
|
||||
this.pendingResponse.writeHead(302, { Location: errorUrl })
|
||||
this.pendingResponse.end()
|
||||
this.pendingResponse = null
|
||||
|
||||
logEvent('tengu_oauth_automatic_redirect_error', {})
|
||||
}
|
||||
|
||||
private startLocalListener(onReady: () => Promise<void>): void {
|
||||
// Server is already created and listening, just set up handlers
|
||||
this.localServer.on('request', this.handleRedirect.bind(this))
|
||||
this.localServer.on('error', this.handleError.bind(this))
|
||||
|
||||
// Server is already listening, so we can call onReady immediately
|
||||
void onReady()
|
||||
}
|
||||
|
||||
private handleRedirect(req: IncomingMessage, res: ServerResponse): void {
|
||||
const parsedUrl = new URL(
|
||||
req.url || '',
|
||||
`http://${req.headers.host || 'localhost'}`,
|
||||
)
|
||||
|
||||
if (parsedUrl.pathname !== this.callbackPath) {
|
||||
res.writeHead(404)
|
||||
res.end()
|
||||
return
|
||||
}
|
||||
|
||||
const authCode = parsedUrl.searchParams.get('code') ?? undefined
|
||||
const state = parsedUrl.searchParams.get('state') ?? undefined
|
||||
|
||||
this.validateAndRespond(authCode, state, res)
|
||||
}
|
||||
|
||||
private validateAndRespond(
|
||||
authCode: string | undefined,
|
||||
state: string | undefined,
|
||||
res: ServerResponse,
|
||||
): void {
|
||||
if (!authCode) {
|
||||
res.writeHead(400)
|
||||
res.end('Authorization code not found')
|
||||
this.reject(new Error('No authorization code received'))
|
||||
return
|
||||
}
|
||||
|
||||
if (state !== this.expectedState) {
|
||||
res.writeHead(400)
|
||||
res.end('Invalid state parameter')
|
||||
this.reject(new Error('Invalid state parameter'))
|
||||
return
|
||||
}
|
||||
|
||||
// Store the response for later redirect
|
||||
this.pendingResponse = res
|
||||
|
||||
this.resolve(authCode)
|
||||
}
|
||||
|
||||
private handleError(err: Error): void {
|
||||
logError(err)
|
||||
this.close()
|
||||
this.reject(err)
|
||||
}
|
||||
|
||||
private resolve(authorizationCode: string): void {
|
||||
if (this.promiseResolver) {
|
||||
this.promiseResolver(authorizationCode)
|
||||
this.promiseResolver = null
|
||||
this.promiseRejecter = null
|
||||
}
|
||||
}
|
||||
|
||||
private reject(error: Error): void {
|
||||
if (this.promiseRejecter) {
|
||||
this.promiseRejecter(error)
|
||||
this.promiseResolver = null
|
||||
this.promiseRejecter = null
|
||||
}
|
||||
}
|
||||
|
||||
close(): void {
|
||||
// If we have a pending response, send a redirect before closing
|
||||
if (this.pendingResponse) {
|
||||
this.handleErrorRedirect()
|
||||
}
|
||||
|
||||
if (this.localServer) {
|
||||
// Remove all listeners to prevent memory leaks
|
||||
this.localServer.removeAllListeners()
|
||||
this.localServer.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
566
src/services/oauth/client.ts
Normal file
566
src/services/oauth/client.ts
Normal file
@@ -0,0 +1,566 @@
|
||||
// OAuth client for handling authentication flows with Claude services
|
||||
import axios from 'axios'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
logEvent,
|
||||
} from 'src/services/analytics/index.js'
|
||||
import {
|
||||
ALL_OAUTH_SCOPES,
|
||||
CLAUDE_AI_INFERENCE_SCOPE,
|
||||
CLAUDE_AI_OAUTH_SCOPES,
|
||||
getOauthConfig,
|
||||
} from '../../constants/oauth.js'
|
||||
import {
|
||||
checkAndRefreshOAuthTokenIfNeeded,
|
||||
getClaudeAIOAuthTokens,
|
||||
hasProfileScope,
|
||||
isClaudeAISubscriber,
|
||||
saveApiKey,
|
||||
} from '../../utils/auth.js'
|
||||
import type { AccountInfo } from '../../utils/config.js'
|
||||
import { getGlobalConfig, saveGlobalConfig } from '../../utils/config.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { getOauthProfileFromOauthToken } from './getOauthProfile.js'
|
||||
import type {
|
||||
BillingType,
|
||||
OAuthProfileResponse,
|
||||
OAuthTokenExchangeResponse,
|
||||
OAuthTokens,
|
||||
RateLimitTier,
|
||||
SubscriptionType,
|
||||
UserRolesResponse,
|
||||
} from './types.js'
|
||||
|
||||
/**
|
||||
* Check if the user has Claude.ai authentication scope
|
||||
* @private Only call this if you're OAuth / auth related code!
|
||||
*/
|
||||
export function shouldUseClaudeAIAuth(scopes: string[] | undefined): boolean {
|
||||
return Boolean(scopes?.includes(CLAUDE_AI_INFERENCE_SCOPE))
|
||||
}
|
||||
|
||||
export function parseScopes(scopeString?: string): string[] {
|
||||
return scopeString?.split(' ').filter(Boolean) ?? []
|
||||
}
|
||||
|
||||
export function buildAuthUrl({
|
||||
codeChallenge,
|
||||
state,
|
||||
port,
|
||||
isManual,
|
||||
loginWithClaudeAi,
|
||||
inferenceOnly,
|
||||
orgUUID,
|
||||
loginHint,
|
||||
loginMethod,
|
||||
}: {
|
||||
codeChallenge: string
|
||||
state: string
|
||||
port: number
|
||||
isManual: boolean
|
||||
loginWithClaudeAi?: boolean
|
||||
inferenceOnly?: boolean
|
||||
orgUUID?: string
|
||||
loginHint?: string
|
||||
loginMethod?: string
|
||||
}): string {
|
||||
const authUrlBase = loginWithClaudeAi
|
||||
? getOauthConfig().CLAUDE_AI_AUTHORIZE_URL
|
||||
: getOauthConfig().CONSOLE_AUTHORIZE_URL
|
||||
|
||||
const authUrl = new URL(authUrlBase)
|
||||
authUrl.searchParams.append('code', 'true') // this tells the login page to show Claude Max upsell
|
||||
authUrl.searchParams.append('client_id', getOauthConfig().CLIENT_ID)
|
||||
authUrl.searchParams.append('response_type', 'code')
|
||||
authUrl.searchParams.append(
|
||||
'redirect_uri',
|
||||
isManual
|
||||
? getOauthConfig().MANUAL_REDIRECT_URL
|
||||
: `http://localhost:${port}/callback`,
|
||||
)
|
||||
const scopesToUse = inferenceOnly
|
||||
? [CLAUDE_AI_INFERENCE_SCOPE] // Long-lived inference-only tokens
|
||||
: ALL_OAUTH_SCOPES
|
||||
authUrl.searchParams.append('scope', scopesToUse.join(' '))
|
||||
authUrl.searchParams.append('code_challenge', codeChallenge)
|
||||
authUrl.searchParams.append('code_challenge_method', 'S256')
|
||||
authUrl.searchParams.append('state', state)
|
||||
|
||||
// Add orgUUID as URL param if provided
|
||||
if (orgUUID) {
|
||||
authUrl.searchParams.append('orgUUID', orgUUID)
|
||||
}
|
||||
|
||||
// Pre-populate email on the login form (standard OIDC parameter)
|
||||
if (loginHint) {
|
||||
authUrl.searchParams.append('login_hint', loginHint)
|
||||
}
|
||||
|
||||
// Request a specific login method (e.g. 'sso', 'magic_link', 'google')
|
||||
if (loginMethod) {
|
||||
authUrl.searchParams.append('login_method', loginMethod)
|
||||
}
|
||||
|
||||
return authUrl.toString()
|
||||
}
|
||||
|
||||
export async function exchangeCodeForTokens(
|
||||
authorizationCode: string,
|
||||
state: string,
|
||||
codeVerifier: string,
|
||||
port: number,
|
||||
useManualRedirect: boolean = false,
|
||||
expiresIn?: number,
|
||||
): Promise<OAuthTokenExchangeResponse> {
|
||||
const requestBody: Record<string, string | number> = {
|
||||
grant_type: 'authorization_code',
|
||||
code: authorizationCode,
|
||||
redirect_uri: useManualRedirect
|
||||
? getOauthConfig().MANUAL_REDIRECT_URL
|
||||
: `http://localhost:${port}/callback`,
|
||||
client_id: getOauthConfig().CLIENT_ID,
|
||||
code_verifier: codeVerifier,
|
||||
state,
|
||||
}
|
||||
|
||||
if (expiresIn !== undefined) {
|
||||
requestBody.expires_in = expiresIn
|
||||
}
|
||||
|
||||
const response = await axios.post(getOauthConfig().TOKEN_URL, requestBody, {
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
timeout: 15000,
|
||||
})
|
||||
|
||||
if (response.status !== 200) {
|
||||
throw new Error(
|
||||
response.status === 401
|
||||
? 'Authentication failed: Invalid authorization code'
|
||||
: `Token exchange failed (${response.status}): ${response.statusText}`,
|
||||
)
|
||||
}
|
||||
logEvent('tengu_oauth_token_exchange_success', {})
|
||||
return response.data
|
||||
}
|
||||
|
||||
export async function refreshOAuthToken(
|
||||
refreshToken: string,
|
||||
{ scopes: requestedScopes }: { scopes?: string[] } = {},
|
||||
): Promise<OAuthTokens> {
|
||||
const requestBody = {
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: refreshToken,
|
||||
client_id: getOauthConfig().CLIENT_ID,
|
||||
// Request specific scopes, defaulting to the full Claude AI set. The
|
||||
// backend's refresh-token grant allows scope expansion beyond what the
|
||||
// initial authorize granted (see ALLOWED_SCOPE_EXPANSIONS), so this is
|
||||
// safe even for tokens issued before scopes were added to the app's
|
||||
// registered oauth_scope.
|
||||
scope: (requestedScopes?.length
|
||||
? requestedScopes
|
||||
: CLAUDE_AI_OAUTH_SCOPES
|
||||
).join(' '),
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.post(getOauthConfig().TOKEN_URL, requestBody, {
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
timeout: 15000,
|
||||
})
|
||||
|
||||
if (response.status !== 200) {
|
||||
throw new Error(`Token refresh failed: ${response.statusText}`)
|
||||
}
|
||||
|
||||
const data = response.data as OAuthTokenExchangeResponse
|
||||
const {
|
||||
access_token: accessToken,
|
||||
refresh_token: newRefreshToken = refreshToken,
|
||||
expires_in: expiresIn,
|
||||
} = data
|
||||
|
||||
const expiresAt = Date.now() + expiresIn * 1000
|
||||
const scopes = parseScopes(data.scope)
|
||||
|
||||
logEvent('tengu_oauth_token_refresh_success', {})
|
||||
|
||||
// Skip the extra /api/oauth/profile round-trip when we already have both
|
||||
// the global-config profile fields AND the secure-storage subscription data.
|
||||
// Routine refreshes satisfy both, so we cut ~7M req/day fleet-wide.
|
||||
//
|
||||
// Checking secure storage (not just config) matters for the
|
||||
// CLAUDE_CODE_OAUTH_REFRESH_TOKEN re-login path: installOAuthTokens runs
|
||||
// performLogout() AFTER we return, wiping secure storage. If we returned
|
||||
// null for subscriptionType here, saveOAuthTokensIfNeeded would persist
|
||||
// null ?? (wiped) ?? null = null, and every future refresh would see the
|
||||
// config guard fields satisfied and skip again, permanently losing the
|
||||
// subscription type for paying users. By passing through existing values,
|
||||
// the re-login path writes cached ?? wiped ?? null = cached; and if secure
|
||||
// storage was already empty we fall through to the fetch.
|
||||
const config = getGlobalConfig()
|
||||
const existing = getClaudeAIOAuthTokens()
|
||||
const haveProfileAlready =
|
||||
config.oauthAccount?.billingType !== undefined &&
|
||||
config.oauthAccount?.accountCreatedAt !== undefined &&
|
||||
config.oauthAccount?.subscriptionCreatedAt !== undefined &&
|
||||
existing?.subscriptionType != null &&
|
||||
existing?.rateLimitTier != null
|
||||
|
||||
const profileInfo = haveProfileAlready
|
||||
? null
|
||||
: await fetchProfileInfo(accessToken)
|
||||
|
||||
// Update the stored properties if they have changed
|
||||
if (profileInfo && config.oauthAccount) {
|
||||
const updates: Partial<AccountInfo> = {}
|
||||
if (profileInfo.displayName !== undefined) {
|
||||
updates.displayName = profileInfo.displayName
|
||||
}
|
||||
if (typeof profileInfo.hasExtraUsageEnabled === 'boolean') {
|
||||
updates.hasExtraUsageEnabled = profileInfo.hasExtraUsageEnabled
|
||||
}
|
||||
if (profileInfo.billingType !== null) {
|
||||
updates.billingType = profileInfo.billingType
|
||||
}
|
||||
if (profileInfo.accountCreatedAt !== undefined) {
|
||||
updates.accountCreatedAt = profileInfo.accountCreatedAt
|
||||
}
|
||||
if (profileInfo.subscriptionCreatedAt !== undefined) {
|
||||
updates.subscriptionCreatedAt = profileInfo.subscriptionCreatedAt
|
||||
}
|
||||
if (Object.keys(updates).length > 0) {
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
oauthAccount: current.oauthAccount
|
||||
? { ...current.oauthAccount, ...updates }
|
||||
: current.oauthAccount,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
accessToken,
|
||||
refreshToken: newRefreshToken,
|
||||
expiresAt,
|
||||
scopes,
|
||||
subscriptionType:
|
||||
profileInfo?.subscriptionType ?? existing?.subscriptionType ?? null,
|
||||
rateLimitTier:
|
||||
profileInfo?.rateLimitTier ?? existing?.rateLimitTier ?? null,
|
||||
profile: profileInfo?.rawProfile,
|
||||
tokenAccount: data.account
|
||||
? {
|
||||
uuid: data.account.uuid,
|
||||
emailAddress: data.account.email_address,
|
||||
organizationUuid: data.organization?.uuid,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
} catch (error) {
|
||||
const responseBody =
|
||||
axios.isAxiosError(error) && error.response?.data
|
||||
? JSON.stringify(error.response.data)
|
||||
: undefined
|
||||
logEvent('tengu_oauth_token_refresh_failure', {
|
||||
error: (error as Error)
|
||||
.message as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...(responseBody && {
|
||||
responseBody:
|
||||
responseBody as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
}),
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchAndStoreUserRoles(
|
||||
accessToken: string,
|
||||
): Promise<void> {
|
||||
const response = await axios.get(getOauthConfig().ROLES_URL, {
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
})
|
||||
|
||||
if (response.status !== 200) {
|
||||
throw new Error(`Failed to fetch user roles: ${response.statusText}`)
|
||||
}
|
||||
const data = response.data as UserRolesResponse
|
||||
const config = getGlobalConfig()
|
||||
|
||||
if (!config.oauthAccount) {
|
||||
throw new Error('OAuth account information not found in config')
|
||||
}
|
||||
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
oauthAccount: current.oauthAccount
|
||||
? {
|
||||
...current.oauthAccount,
|
||||
organizationRole: data.organization_role,
|
||||
workspaceRole: data.workspace_role,
|
||||
organizationName: data.organization_name,
|
||||
}
|
||||
: current.oauthAccount,
|
||||
}))
|
||||
|
||||
logEvent('tengu_oauth_roles_stored', {
|
||||
org_role:
|
||||
data.organization_role as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
}
|
||||
|
||||
export async function createAndStoreApiKey(
|
||||
accessToken: string,
|
||||
): Promise<string | null> {
|
||||
try {
|
||||
const response = await axios.post(getOauthConfig().API_KEY_URL, null, {
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
})
|
||||
|
||||
const apiKey = response.data?.raw_key
|
||||
if (apiKey) {
|
||||
await saveApiKey(apiKey)
|
||||
logEvent('tengu_oauth_api_key', {
|
||||
status:
|
||||
'success' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
statusCode: response.status,
|
||||
})
|
||||
return apiKey
|
||||
}
|
||||
return null
|
||||
} catch (error) {
|
||||
logEvent('tengu_oauth_api_key', {
|
||||
status:
|
||||
'failure' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
error: (error instanceof Error
|
||||
? error.message
|
||||
: String(
|
||||
error,
|
||||
)) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
export function isOAuthTokenExpired(expiresAt: number | null): boolean {
|
||||
if (expiresAt === null) {
|
||||
return false
|
||||
}
|
||||
|
||||
const bufferTime = 5 * 60 * 1000
|
||||
const now = Date.now()
|
||||
const expiresWithBuffer = now + bufferTime
|
||||
return expiresWithBuffer >= expiresAt
|
||||
}
|
||||
|
||||
export async function fetchProfileInfo(accessToken: string): Promise<{
|
||||
subscriptionType: SubscriptionType | null
|
||||
displayName?: string
|
||||
rateLimitTier: RateLimitTier | null
|
||||
hasExtraUsageEnabled: boolean | null
|
||||
billingType: BillingType | null
|
||||
accountCreatedAt?: string
|
||||
subscriptionCreatedAt?: string
|
||||
rawProfile?: OAuthProfileResponse
|
||||
}> {
|
||||
const profile = await getOauthProfileFromOauthToken(accessToken)
|
||||
const orgType = profile?.organization?.organization_type
|
||||
|
||||
// Reuse the logic from fetchSubscriptionType
|
||||
let subscriptionType: SubscriptionType | null = null
|
||||
switch (orgType) {
|
||||
case 'claude_max':
|
||||
subscriptionType = 'max'
|
||||
break
|
||||
case 'claude_pro':
|
||||
subscriptionType = 'pro'
|
||||
break
|
||||
case 'claude_enterprise':
|
||||
subscriptionType = 'enterprise'
|
||||
break
|
||||
case 'claude_team':
|
||||
subscriptionType = 'team'
|
||||
break
|
||||
default:
|
||||
// Return null for unknown organization types
|
||||
subscriptionType = null
|
||||
break
|
||||
}
|
||||
|
||||
const result: {
|
||||
subscriptionType: SubscriptionType | null
|
||||
displayName?: string
|
||||
rateLimitTier: RateLimitTier | null
|
||||
hasExtraUsageEnabled: boolean | null
|
||||
billingType: BillingType | null
|
||||
accountCreatedAt?: string
|
||||
subscriptionCreatedAt?: string
|
||||
} = {
|
||||
subscriptionType,
|
||||
rateLimitTier: profile?.organization?.rate_limit_tier ?? null,
|
||||
hasExtraUsageEnabled:
|
||||
profile?.organization?.has_extra_usage_enabled ?? null,
|
||||
billingType: profile?.organization?.billing_type ?? null,
|
||||
}
|
||||
|
||||
if (profile?.account?.display_name) {
|
||||
result.displayName = profile.account.display_name
|
||||
}
|
||||
|
||||
if (profile?.account?.created_at) {
|
||||
result.accountCreatedAt = profile.account.created_at
|
||||
}
|
||||
|
||||
if (profile?.organization?.subscription_created_at) {
|
||||
result.subscriptionCreatedAt = profile.organization.subscription_created_at
|
||||
}
|
||||
|
||||
logEvent('tengu_oauth_profile_fetch_success', {})
|
||||
|
||||
return { ...result, rawProfile: profile }
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the organization UUID from the OAuth access token
|
||||
* @returns The organization UUID or null if not authenticated
|
||||
*/
|
||||
export async function getOrganizationUUID(): Promise<string | null> {
|
||||
// Check global config first to avoid unnecessary API call
|
||||
const globalConfig = getGlobalConfig()
|
||||
const orgUUID = globalConfig.oauthAccount?.organizationUuid
|
||||
if (orgUUID) {
|
||||
return orgUUID
|
||||
}
|
||||
|
||||
// Fall back to fetching from profile (requires user:profile scope)
|
||||
const accessToken = getClaudeAIOAuthTokens()?.accessToken
|
||||
if (accessToken === undefined || !hasProfileScope()) {
|
||||
return null
|
||||
}
|
||||
const profile = await getOauthProfileFromOauthToken(accessToken)
|
||||
const profileOrgUUID = profile?.organization?.uuid
|
||||
if (!profileOrgUUID) {
|
||||
return null
|
||||
}
|
||||
return profileOrgUUID
|
||||
}
|
||||
|
||||
/**
|
||||
* Populate the OAuth account info if it has not already been cached in config.
|
||||
* @returns Whether or not the oauth account info was populated.
|
||||
*/
|
||||
export async function populateOAuthAccountInfoIfNeeded(): Promise<boolean> {
|
||||
// Check env vars first (synchronous, no network call needed).
|
||||
// SDK callers like Cowork can provide account info directly, which also
|
||||
// eliminates the race condition where early telemetry events lack account info.
|
||||
// NB: If/when adding additional SDK-relevant functionality requiring _other_ OAuth account properties,
|
||||
// please reach out to #proj-cowork so the team can add additional env var fallbacks.
|
||||
const envAccountUuid = process.env.CLAUDE_CODE_ACCOUNT_UUID
|
||||
const envUserEmail = process.env.CLAUDE_CODE_USER_EMAIL
|
||||
const envOrganizationUuid = process.env.CLAUDE_CODE_ORGANIZATION_UUID
|
||||
const hasEnvVars = Boolean(
|
||||
envAccountUuid && envUserEmail && envOrganizationUuid,
|
||||
)
|
||||
if (envAccountUuid && envUserEmail && envOrganizationUuid) {
|
||||
if (!getGlobalConfig().oauthAccount) {
|
||||
storeOAuthAccountInfo({
|
||||
accountUuid: envAccountUuid,
|
||||
emailAddress: envUserEmail,
|
||||
organizationUuid: envOrganizationUuid,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for any in-flight token refresh to complete first, since
|
||||
// refreshOAuthToken already fetches and stores profile info
|
||||
await checkAndRefreshOAuthTokenIfNeeded()
|
||||
|
||||
const config = getGlobalConfig()
|
||||
if (
|
||||
(config.oauthAccount &&
|
||||
config.oauthAccount.billingType !== undefined &&
|
||||
config.oauthAccount.accountCreatedAt !== undefined &&
|
||||
config.oauthAccount.subscriptionCreatedAt !== undefined) ||
|
||||
!isClaudeAISubscriber() ||
|
||||
!hasProfileScope()
|
||||
) {
|
||||
return false
|
||||
}
|
||||
|
||||
const tokens = getClaudeAIOAuthTokens()
|
||||
if (tokens?.accessToken) {
|
||||
const profile = await getOauthProfileFromOauthToken(tokens.accessToken)
|
||||
if (profile) {
|
||||
if (hasEnvVars) {
|
||||
logForDebugging(
|
||||
'OAuth profile fetch succeeded, overriding env var account info',
|
||||
{ level: 'info' },
|
||||
)
|
||||
}
|
||||
storeOAuthAccountInfo({
|
||||
accountUuid: profile.account.uuid,
|
||||
emailAddress: profile.account.email,
|
||||
organizationUuid: profile.organization.uuid,
|
||||
displayName: profile.account.display_name || undefined,
|
||||
hasExtraUsageEnabled:
|
||||
profile.organization.has_extra_usage_enabled ?? false,
|
||||
billingType: profile.organization.billing_type ?? undefined,
|
||||
accountCreatedAt: profile.account.created_at,
|
||||
subscriptionCreatedAt:
|
||||
profile.organization.subscription_created_at ?? undefined,
|
||||
})
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export function storeOAuthAccountInfo({
|
||||
accountUuid,
|
||||
emailAddress,
|
||||
organizationUuid,
|
||||
displayName,
|
||||
hasExtraUsageEnabled,
|
||||
billingType,
|
||||
accountCreatedAt,
|
||||
subscriptionCreatedAt,
|
||||
}: {
|
||||
accountUuid: string
|
||||
emailAddress: string
|
||||
organizationUuid: string | undefined
|
||||
displayName?: string
|
||||
hasExtraUsageEnabled?: boolean
|
||||
billingType?: BillingType
|
||||
accountCreatedAt?: string
|
||||
subscriptionCreatedAt?: string
|
||||
}): void {
|
||||
const accountInfo: AccountInfo = {
|
||||
accountUuid,
|
||||
emailAddress,
|
||||
organizationUuid,
|
||||
hasExtraUsageEnabled,
|
||||
billingType,
|
||||
accountCreatedAt,
|
||||
subscriptionCreatedAt,
|
||||
}
|
||||
if (displayName) {
|
||||
accountInfo.displayName = displayName
|
||||
}
|
||||
saveGlobalConfig(current => {
|
||||
// For oauthAccount we need to compare content since it's an object
|
||||
if (
|
||||
current.oauthAccount?.accountUuid === accountInfo.accountUuid &&
|
||||
current.oauthAccount?.emailAddress === accountInfo.emailAddress &&
|
||||
current.oauthAccount?.organizationUuid === accountInfo.organizationUuid &&
|
||||
current.oauthAccount?.displayName === accountInfo.displayName &&
|
||||
current.oauthAccount?.hasExtraUsageEnabled ===
|
||||
accountInfo.hasExtraUsageEnabled &&
|
||||
current.oauthAccount?.billingType === accountInfo.billingType &&
|
||||
current.oauthAccount?.accountCreatedAt === accountInfo.accountCreatedAt &&
|
||||
current.oauthAccount?.subscriptionCreatedAt ===
|
||||
accountInfo.subscriptionCreatedAt
|
||||
) {
|
||||
return current
|
||||
}
|
||||
return { ...current, oauthAccount: accountInfo }
|
||||
})
|
||||
}
|
||||
23
src/services/oauth/crypto.ts
Normal file
23
src/services/oauth/crypto.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { createHash, randomBytes } from 'crypto'
|
||||
|
||||
function base64URLEncode(buffer: Buffer): string {
|
||||
return buffer
|
||||
.toString('base64')
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=/g, '')
|
||||
}
|
||||
|
||||
export function generateCodeVerifier(): string {
|
||||
return base64URLEncode(randomBytes(32))
|
||||
}
|
||||
|
||||
export function generateCodeChallenge(verifier: string): string {
|
||||
const hash = createHash('sha256')
|
||||
hash.update(verifier)
|
||||
return base64URLEncode(hash.digest())
|
||||
}
|
||||
|
||||
export function generateState(): string {
|
||||
return base64URLEncode(randomBytes(32))
|
||||
}
|
||||
53
src/services/oauth/getOauthProfile.ts
Normal file
53
src/services/oauth/getOauthProfile.ts
Normal file
@@ -0,0 +1,53 @@
|
||||
import axios from 'axios'
|
||||
import { getOauthConfig, OAUTH_BETA_HEADER } from 'src/constants/oauth.js'
|
||||
import type { OAuthProfileResponse } from 'src/services/oauth/types.js'
|
||||
import { getAnthropicApiKey } from 'src/utils/auth.js'
|
||||
import { getGlobalConfig } from 'src/utils/config.js'
|
||||
import { logError } from 'src/utils/log.js'
|
||||
export async function getOauthProfileFromApiKey(): Promise<
|
||||
OAuthProfileResponse | undefined
|
||||
> {
|
||||
// Assumes interactive session
|
||||
const config = getGlobalConfig()
|
||||
const accountUuid = config.oauthAccount?.accountUuid
|
||||
const apiKey = getAnthropicApiKey()
|
||||
|
||||
// Need both account UUID and API key to check
|
||||
if (!accountUuid || !apiKey) {
|
||||
return
|
||||
}
|
||||
const endpoint = `${getOauthConfig().BASE_API_URL}/api/claude_cli_profile`
|
||||
try {
|
||||
const response = await axios.get<OAuthProfileResponse>(endpoint, {
|
||||
headers: {
|
||||
'x-api-key': apiKey,
|
||||
'anthropic-beta': OAUTH_BETA_HEADER,
|
||||
},
|
||||
params: {
|
||||
account_uuid: accountUuid,
|
||||
},
|
||||
timeout: 10000,
|
||||
})
|
||||
return response.data
|
||||
} catch (error) {
|
||||
logError(error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
export async function getOauthProfileFromOauthToken(
|
||||
accessToken: string,
|
||||
): Promise<OAuthProfileResponse | undefined> {
|
||||
const endpoint = `${getOauthConfig().BASE_API_URL}/api/oauth/profile`
|
||||
try {
|
||||
const response = await axios.get<OAuthProfileResponse>(endpoint, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
timeout: 10000,
|
||||
})
|
||||
return response.data
|
||||
} catch (error) {
|
||||
logError(error as Error)
|
||||
}
|
||||
}
|
||||
198
src/services/oauth/index.ts
Normal file
198
src/services/oauth/index.ts
Normal file
@@ -0,0 +1,198 @@
|
||||
import { logEvent } from 'src/services/analytics/index.js'
|
||||
import { openBrowser } from '../../utils/browser.js'
|
||||
import { AuthCodeListener } from './auth-code-listener.js'
|
||||
import * as client from './client.js'
|
||||
import * as crypto from './crypto.js'
|
||||
import type {
|
||||
OAuthProfileResponse,
|
||||
OAuthTokenExchangeResponse,
|
||||
OAuthTokens,
|
||||
RateLimitTier,
|
||||
SubscriptionType,
|
||||
} from './types.js'
|
||||
|
||||
/**
|
||||
* OAuth service that handles the OAuth 2.0 authorization code flow with PKCE.
|
||||
*
|
||||
* Supports two ways to get authorization codes:
|
||||
* 1. Automatic: Opens browser, redirects to localhost where we capture the code
|
||||
* 2. Manual: User manually copies and pastes the code (used in non-browser environments)
|
||||
*/
|
||||
export class OAuthService {
|
||||
private codeVerifier: string
|
||||
private authCodeListener: AuthCodeListener | null = null
|
||||
private port: number | null = null
|
||||
private manualAuthCodeResolver: ((authorizationCode: string) => void) | null =
|
||||
null
|
||||
|
||||
constructor() {
|
||||
this.codeVerifier = crypto.generateCodeVerifier()
|
||||
}
|
||||
|
||||
async startOAuthFlow(
|
||||
authURLHandler: (url: string, automaticUrl?: string) => Promise<void>,
|
||||
options?: {
|
||||
loginWithClaudeAi?: boolean
|
||||
inferenceOnly?: boolean
|
||||
expiresIn?: number
|
||||
orgUUID?: string
|
||||
loginHint?: string
|
||||
loginMethod?: string
|
||||
/**
|
||||
* Don't call openBrowser(). Caller takes both URLs via authURLHandler
|
||||
* and decides how/where to open them. Used by the SDK control protocol
|
||||
* (claude_authenticate) where the SDK client owns the user's display,
|
||||
* not this process.
|
||||
*/
|
||||
skipBrowserOpen?: boolean
|
||||
},
|
||||
): Promise<OAuthTokens> {
|
||||
// Create OAuth callback listener and start it
|
||||
this.authCodeListener = new AuthCodeListener()
|
||||
this.port = await this.authCodeListener.start()
|
||||
|
||||
// Generate PKCE values and state
|
||||
const codeChallenge = crypto.generateCodeChallenge(this.codeVerifier)
|
||||
const state = crypto.generateState()
|
||||
|
||||
// Build auth URLs for both automatic and manual flows
|
||||
const opts = {
|
||||
codeChallenge,
|
||||
state,
|
||||
port: this.port,
|
||||
loginWithClaudeAi: options?.loginWithClaudeAi,
|
||||
inferenceOnly: options?.inferenceOnly,
|
||||
orgUUID: options?.orgUUID,
|
||||
loginHint: options?.loginHint,
|
||||
loginMethod: options?.loginMethod,
|
||||
}
|
||||
const manualFlowUrl = client.buildAuthUrl({ ...opts, isManual: true })
|
||||
const automaticFlowUrl = client.buildAuthUrl({ ...opts, isManual: false })
|
||||
|
||||
// Wait for either automatic or manual auth code
|
||||
const authorizationCode = await this.waitForAuthorizationCode(
|
||||
state,
|
||||
async () => {
|
||||
if (options?.skipBrowserOpen) {
|
||||
// Hand both URLs to the caller. The automatic one still works
|
||||
// if the caller opens it on the same host (localhost listener
|
||||
// is running); the manual one works from anywhere.
|
||||
await authURLHandler(manualFlowUrl, automaticFlowUrl)
|
||||
} else {
|
||||
await authURLHandler(manualFlowUrl) // Show manual option to user
|
||||
await openBrowser(automaticFlowUrl) // Try automatic flow
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
// Check if the automatic flow is still active (has a pending response)
|
||||
const isAutomaticFlow = this.authCodeListener?.hasPendingResponse() ?? false
|
||||
logEvent('tengu_oauth_auth_code_received', { automatic: isAutomaticFlow })
|
||||
|
||||
try {
|
||||
// Exchange authorization code for tokens
|
||||
const tokenResponse = await client.exchangeCodeForTokens(
|
||||
authorizationCode,
|
||||
state,
|
||||
this.codeVerifier,
|
||||
this.port!,
|
||||
!isAutomaticFlow, // Pass isManual=true if it's NOT automatic flow
|
||||
options?.expiresIn,
|
||||
)
|
||||
|
||||
// Fetch profile info (subscription type and rate limit tier) for the
|
||||
// returned OAuthTokens. Logout and account storage are handled by the
|
||||
// caller (installOAuthTokens in auth.ts).
|
||||
const profileInfo = await client.fetchProfileInfo(
|
||||
tokenResponse.access_token,
|
||||
)
|
||||
|
||||
// Handle success redirect for automatic flow
|
||||
if (isAutomaticFlow) {
|
||||
const scopes = client.parseScopes(tokenResponse.scope)
|
||||
this.authCodeListener?.handleSuccessRedirect(scopes)
|
||||
}
|
||||
|
||||
return this.formatTokens(
|
||||
tokenResponse,
|
||||
profileInfo.subscriptionType,
|
||||
profileInfo.rateLimitTier,
|
||||
profileInfo.rawProfile,
|
||||
)
|
||||
} catch (error) {
|
||||
// If we have a pending response, send an error redirect before closing
|
||||
if (isAutomaticFlow) {
|
||||
this.authCodeListener?.handleErrorRedirect()
|
||||
}
|
||||
throw error
|
||||
} finally {
|
||||
// Always cleanup
|
||||
this.authCodeListener?.close()
|
||||
}
|
||||
}
|
||||
|
||||
private async waitForAuthorizationCode(
|
||||
state: string,
|
||||
onReady: () => Promise<void>,
|
||||
): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
// Set up manual auth code resolver
|
||||
this.manualAuthCodeResolver = resolve
|
||||
|
||||
// Start automatic flow
|
||||
this.authCodeListener
|
||||
?.waitForAuthorization(state, onReady)
|
||||
.then(authorizationCode => {
|
||||
this.manualAuthCodeResolver = null
|
||||
resolve(authorizationCode)
|
||||
})
|
||||
.catch(error => {
|
||||
this.manualAuthCodeResolver = null
|
||||
reject(error)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Handle manual flow callback when user pastes the auth code
|
||||
handleManualAuthCodeInput(params: {
|
||||
authorizationCode: string
|
||||
state: string
|
||||
}): void {
|
||||
if (this.manualAuthCodeResolver) {
|
||||
this.manualAuthCodeResolver(params.authorizationCode)
|
||||
this.manualAuthCodeResolver = null
|
||||
// Close the auth code listener since manual input was used
|
||||
this.authCodeListener?.close()
|
||||
}
|
||||
}
|
||||
|
||||
private formatTokens(
|
||||
response: OAuthTokenExchangeResponse,
|
||||
subscriptionType: SubscriptionType | null,
|
||||
rateLimitTier: RateLimitTier | null,
|
||||
profile?: OAuthProfileResponse,
|
||||
): OAuthTokens {
|
||||
return {
|
||||
accessToken: response.access_token,
|
||||
refreshToken: response.refresh_token,
|
||||
expiresAt: Date.now() + response.expires_in * 1000,
|
||||
scopes: client.parseScopes(response.scope),
|
||||
subscriptionType,
|
||||
rateLimitTier,
|
||||
profile,
|
||||
tokenAccount: response.account
|
||||
? {
|
||||
uuid: response.account.uuid,
|
||||
emailAddress: response.account.email_address,
|
||||
organizationUuid: response.organization?.uuid,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up any resources (like the local server)
|
||||
cleanup(): void {
|
||||
this.authCodeListener?.close()
|
||||
this.manualAuthCodeResolver = null
|
||||
}
|
||||
}
|
||||
184
src/services/plugins/PluginInstallationManager.ts
Normal file
184
src/services/plugins/PluginInstallationManager.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
/**
|
||||
* Background plugin and marketplace installation manager
|
||||
*
|
||||
* This module handles automatic installation of plugins and marketplaces
|
||||
* from trusted sources (repository and user settings) without blocking startup.
|
||||
*/
|
||||
|
||||
import type { AppState } from '../../state/AppState.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { logForDiagnosticsNoPII } from '../../utils/diagLogs.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import {
|
||||
clearMarketplacesCache,
|
||||
getDeclaredMarketplaces,
|
||||
loadKnownMarketplacesConfig,
|
||||
} from '../../utils/plugins/marketplaceManager.js'
|
||||
import { clearPluginCache } from '../../utils/plugins/pluginLoader.js'
|
||||
import {
|
||||
diffMarketplaces,
|
||||
reconcileMarketplaces,
|
||||
} from '../../utils/plugins/reconciler.js'
|
||||
import { refreshActivePlugins } from '../../utils/plugins/refresh.js'
|
||||
import { logEvent } from '../analytics/index.js'
|
||||
|
||||
type SetAppState = (f: (prevState: AppState) => AppState) => void
|
||||
|
||||
/**
|
||||
* Update marketplace installation status in app state
|
||||
*/
|
||||
function updateMarketplaceStatus(
|
||||
setAppState: SetAppState,
|
||||
name: string,
|
||||
status: 'pending' | 'installing' | 'installed' | 'failed',
|
||||
error?: string,
|
||||
): void {
|
||||
setAppState(prevState => ({
|
||||
...prevState,
|
||||
plugins: {
|
||||
...prevState.plugins,
|
||||
installationStatus: {
|
||||
...prevState.plugins.installationStatus,
|
||||
marketplaces: prevState.plugins.installationStatus.marketplaces.map(
|
||||
m => (m.name === name ? { ...m, status, error } : m),
|
||||
),
|
||||
},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform background plugin startup checks and installations.
|
||||
*
|
||||
* This is a thin wrapper around reconcileMarketplaces() that maps onProgress
|
||||
* events to AppState updates for the REPL UI. After marketplaces are
|
||||
* reconciled:
|
||||
* - New installs → auto-refresh plugins (fixes "plugin-not-found" errors
|
||||
* from the initial cache-only load on fresh homespace/cleared cache)
|
||||
* - Updates only → set needsRefresh, show notification for /reload-plugins
|
||||
*/
|
||||
export async function performBackgroundPluginInstallations(
|
||||
setAppState: SetAppState,
|
||||
): Promise<void> {
|
||||
logForDebugging('performBackgroundPluginInstallations called')
|
||||
|
||||
try {
|
||||
// Compute diff upfront for initial UI status (pending spinners)
|
||||
const declared = getDeclaredMarketplaces()
|
||||
const materialized = await loadKnownMarketplacesConfig().catch(() => ({}))
|
||||
const diff = diffMarketplaces(declared, materialized)
|
||||
|
||||
const pendingNames = [
|
||||
...diff.missing,
|
||||
...diff.sourceChanged.map(c => c.name),
|
||||
]
|
||||
|
||||
// Initialize AppState with pending status. No per-plugin pending status —
|
||||
// plugin load is fast (cache hit or local copy); marketplace clone is the
|
||||
// slow part worth showing progress for.
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
plugins: {
|
||||
...prev.plugins,
|
||||
installationStatus: {
|
||||
marketplaces: pendingNames.map(name => ({
|
||||
name,
|
||||
status: 'pending' as const,
|
||||
})),
|
||||
plugins: [],
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
if (pendingNames.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`Installing ${pendingNames.length} marketplace(s) in background`,
|
||||
)
|
||||
|
||||
const result = await reconcileMarketplaces({
|
||||
onProgress: event => {
|
||||
switch (event.type) {
|
||||
case 'installing':
|
||||
updateMarketplaceStatus(setAppState, event.name, 'installing')
|
||||
break
|
||||
case 'installed':
|
||||
updateMarketplaceStatus(setAppState, event.name, 'installed')
|
||||
break
|
||||
case 'failed':
|
||||
updateMarketplaceStatus(
|
||||
setAppState,
|
||||
event.name,
|
||||
'failed',
|
||||
event.error,
|
||||
)
|
||||
break
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
const metrics = {
|
||||
installed_count: result.installed.length,
|
||||
updated_count: result.updated.length,
|
||||
failed_count: result.failed.length,
|
||||
up_to_date_count: result.upToDate.length,
|
||||
}
|
||||
logEvent('tengu_marketplace_background_install', metrics)
|
||||
logForDiagnosticsNoPII(
|
||||
'info',
|
||||
'tengu_marketplace_background_install',
|
||||
metrics,
|
||||
)
|
||||
|
||||
if (result.installed.length > 0) {
|
||||
// New marketplaces were installed — auto-refresh plugins. This fixes
|
||||
// "Plugin not found in marketplace" errors from the initial cache-only
|
||||
// load (e.g., fresh homespace where marketplace cache was empty).
|
||||
// refreshActivePlugins clears all caches, reloads plugins, and bumps
|
||||
// pluginReconnectKey so MCP connections are re-established.
|
||||
clearMarketplacesCache()
|
||||
logForDebugging(
|
||||
`Auto-refreshing plugins after ${result.installed.length} new marketplace(s) installed`,
|
||||
)
|
||||
try {
|
||||
await refreshActivePlugins(setAppState)
|
||||
} catch (refreshError) {
|
||||
// If auto-refresh fails, fall back to needsRefresh notification so
|
||||
// the user can manually run /reload-plugins to recover.
|
||||
logError(refreshError)
|
||||
logForDebugging(
|
||||
`Auto-refresh failed, falling back to needsRefresh: ${refreshError}`,
|
||||
{ level: 'warn' },
|
||||
)
|
||||
clearPluginCache(
|
||||
'performBackgroundPluginInstallations: auto-refresh failed',
|
||||
)
|
||||
setAppState(prev => {
|
||||
if (prev.plugins.needsRefresh) return prev
|
||||
return {
|
||||
...prev,
|
||||
plugins: { ...prev.plugins, needsRefresh: true },
|
||||
}
|
||||
})
|
||||
}
|
||||
} else if (result.updated.length > 0) {
|
||||
// Existing marketplaces updated — notify user to run /reload-plugins.
|
||||
// Updates are less urgent and the user should choose when to apply them.
|
||||
clearMarketplacesCache()
|
||||
clearPluginCache(
|
||||
'performBackgroundPluginInstallations: marketplaces reconciled',
|
||||
)
|
||||
setAppState(prev => {
|
||||
if (prev.plugins.needsRefresh) return prev
|
||||
return {
|
||||
...prev,
|
||||
plugins: { ...prev.plugins, needsRefresh: true },
|
||||
}
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
}
|
||||
}
|
||||
344
src/services/plugins/pluginCliCommands.ts
Normal file
344
src/services/plugins/pluginCliCommands.ts
Normal file
@@ -0,0 +1,344 @@
|
||||
/**
|
||||
* CLI command wrappers for plugin operations
|
||||
*
|
||||
* This module provides thin wrappers around the core plugin operations
|
||||
* that handle CLI-specific concerns like console output and process exit.
|
||||
*
|
||||
* For the core operations (without CLI side effects), see pluginOperations.ts
|
||||
*/
|
||||
import figures from 'figures'
|
||||
import { errorMessage } from '../../utils/errors.js'
|
||||
import { gracefulShutdown } from '../../utils/gracefulShutdown.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { getManagedPluginNames } from '../../utils/plugins/managedPlugins.js'
|
||||
import { parsePluginIdentifier } from '../../utils/plugins/pluginIdentifier.js'
|
||||
import type { PluginScope } from '../../utils/plugins/schemas.js'
|
||||
import { writeToStdout } from '../../utils/process.js'
|
||||
import {
|
||||
buildPluginTelemetryFields,
|
||||
classifyPluginCommandError,
|
||||
} from '../../utils/telemetry/pluginTelemetry.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
logEvent,
|
||||
} from '../analytics/index.js'
|
||||
import {
|
||||
disableAllPluginsOp,
|
||||
disablePluginOp,
|
||||
enablePluginOp,
|
||||
type InstallableScope,
|
||||
installPluginOp,
|
||||
uninstallPluginOp,
|
||||
updatePluginOp,
|
||||
VALID_INSTALLABLE_SCOPES,
|
||||
VALID_UPDATE_SCOPES,
|
||||
} from './pluginOperations.js'
|
||||
|
||||
export { VALID_INSTALLABLE_SCOPES, VALID_UPDATE_SCOPES }
|
||||
|
||||
type PluginCliCommand =
|
||||
| 'install'
|
||||
| 'uninstall'
|
||||
| 'enable'
|
||||
| 'disable'
|
||||
| 'disable-all'
|
||||
| 'update'
|
||||
|
||||
/**
|
||||
* Generic error handler for plugin CLI commands. Emits
|
||||
* tengu_plugin_command_failed before exit so dashboards can compute a
|
||||
* success rate against the corresponding success events.
|
||||
*/
|
||||
function handlePluginCommandError(
|
||||
error: unknown,
|
||||
command: PluginCliCommand,
|
||||
plugin?: string,
|
||||
): never {
|
||||
logError(error)
|
||||
const operation = plugin
|
||||
? `${command} plugin "${plugin}"`
|
||||
: command === 'disable-all'
|
||||
? 'disable all plugins'
|
||||
: `${command} plugins`
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output
|
||||
console.error(
|
||||
`${figures.cross} Failed to ${operation}: ${errorMessage(error)}`,
|
||||
)
|
||||
const telemetryFields = plugin
|
||||
? (() => {
|
||||
const { name, marketplace } = parsePluginIdentifier(plugin)
|
||||
return {
|
||||
_PROTO_plugin_name:
|
||||
name as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
...(marketplace && {
|
||||
_PROTO_marketplace_name:
|
||||
marketplace as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
}),
|
||||
...buildPluginTelemetryFields(
|
||||
name,
|
||||
marketplace,
|
||||
getManagedPluginNames(),
|
||||
),
|
||||
}
|
||||
})()
|
||||
: {}
|
||||
logEvent('tengu_plugin_command_failed', {
|
||||
command:
|
||||
command as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
error_category: classifyPluginCommandError(
|
||||
error,
|
||||
) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...telemetryFields,
|
||||
})
|
||||
// eslint-disable-next-line custom-rules/no-process-exit
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
/**
|
||||
* CLI command: Install a plugin non-interactively
|
||||
* @param plugin Plugin identifier (name or plugin@marketplace)
|
||||
* @param scope Installation scope: user, project, or local (defaults to 'user')
|
||||
*/
|
||||
export async function installPlugin(
|
||||
plugin: string,
|
||||
scope: InstallableScope = 'user',
|
||||
): Promise<void> {
|
||||
try {
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output
|
||||
console.log(`Installing plugin "${plugin}"...`)
|
||||
|
||||
const result = await installPluginOp(plugin, scope)
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.message)
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output
|
||||
console.log(`${figures.tick} ${result.message}`)
|
||||
|
||||
// _PROTO_* routes to PII-tagged plugin_name/marketplace_name BQ columns.
|
||||
// Unredacted plugin_id was previously logged to general-access
|
||||
// additional_metadata for all users — dropped in favor of the privileged
|
||||
// column route.
|
||||
const { name, marketplace } = parsePluginIdentifier(
|
||||
result.pluginId || plugin,
|
||||
)
|
||||
logEvent('tengu_plugin_installed_cli', {
|
||||
_PROTO_plugin_name:
|
||||
name as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
...(marketplace && {
|
||||
_PROTO_marketplace_name:
|
||||
marketplace as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
}),
|
||||
scope: (result.scope ||
|
||||
scope) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
install_source:
|
||||
'cli-explicit' as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...buildPluginTelemetryFields(name, marketplace, getManagedPluginNames()),
|
||||
})
|
||||
|
||||
// eslint-disable-next-line custom-rules/no-process-exit
|
||||
process.exit(0)
|
||||
} catch (error) {
|
||||
handlePluginCommandError(error, 'install', plugin)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* CLI command: Uninstall a plugin non-interactively
|
||||
* @param plugin Plugin name or plugin@marketplace identifier
|
||||
* @param scope Uninstall from scope: user, project, or local (defaults to 'user')
|
||||
*/
|
||||
export async function uninstallPlugin(
|
||||
plugin: string,
|
||||
scope: InstallableScope = 'user',
|
||||
keepData = false,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const result = await uninstallPluginOp(plugin, scope, !keepData)
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.message)
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output
|
||||
console.log(`${figures.tick} ${result.message}`)
|
||||
|
||||
const { name, marketplace } = parsePluginIdentifier(
|
||||
result.pluginId || plugin,
|
||||
)
|
||||
logEvent('tengu_plugin_uninstalled_cli', {
|
||||
_PROTO_plugin_name:
|
||||
name as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
...(marketplace && {
|
||||
_PROTO_marketplace_name:
|
||||
marketplace as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
}),
|
||||
scope: (result.scope ||
|
||||
scope) as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...buildPluginTelemetryFields(name, marketplace, getManagedPluginNames()),
|
||||
})
|
||||
|
||||
// eslint-disable-next-line custom-rules/no-process-exit
|
||||
process.exit(0)
|
||||
} catch (error) {
|
||||
handlePluginCommandError(error, 'uninstall', plugin)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* CLI command: Enable a plugin non-interactively
|
||||
* @param plugin Plugin name or plugin@marketplace identifier
|
||||
* @param scope Optional scope. If not provided, finds the most specific scope for the current project.
|
||||
*/
|
||||
export async function enablePlugin(
|
||||
plugin: string,
|
||||
scope?: InstallableScope,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const result = await enablePluginOp(plugin, scope)
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.message)
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output
|
||||
console.log(`${figures.tick} ${result.message}`)
|
||||
|
||||
const { name, marketplace } = parsePluginIdentifier(
|
||||
result.pluginId || plugin,
|
||||
)
|
||||
logEvent('tengu_plugin_enabled_cli', {
|
||||
_PROTO_plugin_name:
|
||||
name as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
...(marketplace && {
|
||||
_PROTO_marketplace_name:
|
||||
marketplace as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
}),
|
||||
scope:
|
||||
result.scope as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...buildPluginTelemetryFields(name, marketplace, getManagedPluginNames()),
|
||||
})
|
||||
|
||||
// eslint-disable-next-line custom-rules/no-process-exit
|
||||
process.exit(0)
|
||||
} catch (error) {
|
||||
handlePluginCommandError(error, 'enable', plugin)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* CLI command: Disable a plugin non-interactively
|
||||
* @param plugin Plugin name or plugin@marketplace identifier
|
||||
* @param scope Optional scope. If not provided, finds the most specific scope for the current project.
|
||||
*/
|
||||
export async function disablePlugin(
|
||||
plugin: string,
|
||||
scope?: InstallableScope,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const result = await disablePluginOp(plugin, scope)
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.message)
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output
|
||||
console.log(`${figures.tick} ${result.message}`)
|
||||
|
||||
const { name, marketplace } = parsePluginIdentifier(
|
||||
result.pluginId || plugin,
|
||||
)
|
||||
logEvent('tengu_plugin_disabled_cli', {
|
||||
_PROTO_plugin_name:
|
||||
name as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
...(marketplace && {
|
||||
_PROTO_marketplace_name:
|
||||
marketplace as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
}),
|
||||
scope:
|
||||
result.scope as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...buildPluginTelemetryFields(name, marketplace, getManagedPluginNames()),
|
||||
})
|
||||
|
||||
// eslint-disable-next-line custom-rules/no-process-exit
|
||||
process.exit(0)
|
||||
} catch (error) {
|
||||
handlePluginCommandError(error, 'disable', plugin)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* CLI command: Disable all enabled plugins non-interactively
|
||||
*/
|
||||
export async function disableAllPlugins(): Promise<void> {
|
||||
try {
|
||||
const result = await disableAllPluginsOp()
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.message)
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noConsole:: intentional console output
|
||||
console.log(`${figures.tick} ${result.message}`)
|
||||
|
||||
logEvent('tengu_plugin_disabled_all_cli', {})
|
||||
|
||||
// eslint-disable-next-line custom-rules/no-process-exit
|
||||
process.exit(0)
|
||||
} catch (error) {
|
||||
handlePluginCommandError(error, 'disable-all')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* CLI command: Update a plugin non-interactively
|
||||
* @param plugin Plugin name or plugin@marketplace identifier
|
||||
* @param scope Scope to update
|
||||
*/
|
||||
export async function updatePluginCli(
|
||||
plugin: string,
|
||||
scope: PluginScope,
|
||||
): Promise<void> {
|
||||
try {
|
||||
writeToStdout(
|
||||
`Checking for updates for plugin "${plugin}" at ${scope} scope…\n`,
|
||||
)
|
||||
|
||||
const result = await updatePluginOp(plugin, scope)
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.message)
|
||||
}
|
||||
|
||||
writeToStdout(`${figures.tick} ${result.message}\n`)
|
||||
|
||||
if (!result.alreadyUpToDate) {
|
||||
const { name, marketplace } = parsePluginIdentifier(
|
||||
result.pluginId || plugin,
|
||||
)
|
||||
logEvent('tengu_plugin_updated_cli', {
|
||||
_PROTO_plugin_name:
|
||||
name as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
...(marketplace && {
|
||||
_PROTO_marketplace_name:
|
||||
marketplace as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
}),
|
||||
old_version: (result.oldVersion ||
|
||||
'unknown') as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
new_version: (result.newVersion ||
|
||||
'unknown') as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
...buildPluginTelemetryFields(
|
||||
name,
|
||||
marketplace,
|
||||
getManagedPluginNames(),
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
await gracefulShutdown(0)
|
||||
} catch (error) {
|
||||
handlePluginCommandError(error, 'update', plugin)
|
||||
}
|
||||
}
|
||||
1088
src/services/plugins/pluginOperations.ts
Normal file
1088
src/services/plugins/pluginOperations.ts
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user