From 3de07a99ed5aec8a2910256252e04ee04a575015 Mon Sep 17 00:00:00 2001 From: HUQIANTAO <58421104+HUQIANTAO@users.noreply.github.com> Date: Wed, 3 Jun 2026 18:33:30 +0800 Subject: [PATCH] perf(engine): memoize estimated_input_tokens via content-keyed cache The token estimator walks the full session.messages and the active system prompt. Five call sites per turn in the engine (capacity pre/post tool checkpoints, error escalation, the seam manager, the trim budget check) plus four TUI/command consumers (footer, /status, /debug, context inspector) all re-walked the same data independently. On a 200-message history with 5 KB of tool results that is roughly 2 ms per call, or ~20 ms of pure waste on a single turn. Introduce a process-local TokenEstimateCache keyed on (session.messages_revision, system_prompt_fingerprint). Repeated calls with the same inputs return the cached value without re-walking the message list. The cache invalidates as soon as either input changes: * session.messages_revision is a monotonic counter bumped in Session::add_message, Session::replace_messages, the new Session::bump_messages_revision helper, and at every direct session.messages mutation site in core/engine.rs and core/engine/capacity_flow.rs. * system_prompt_fingerprint is a stable 64-bit hash of the SystemPrompt::Text or SystemPrompt::Blocks payload. Also restructures layered_context_checkpoint to compute the estimated token count before taking a long-lived &SeamManager borrow, and re-routes the capacity pre/post tool checkpoints to compute the observation into a local before calling capacity_controller.observe_*. Both refactors are required to satisfy the borrow checker once estimated_input_tokens requires &mut self. Tests: 10 new unit tests cover the miss/hit path, revision bumps, system-prompt changes, audit-ring capacity, and downward-revision no-ops. The full 157-test engine suite still passes. 
--- crates/tui/src/core/engine.rs | 40 ++- crates/tui/src/core/engine/capacity_flow.rs | 17 +- crates/tui/src/core/engine/context.rs | 3 + .../src/core/engine/token_estimate_cache.rs | 312 ++++++++++++++++++ crates/tui/src/core/session.rs | 29 ++ 5 files changed, 383 insertions(+), 18 deletions(-) create mode 100644 crates/tui/src/core/engine/token_estimate_cache.rs diff --git a/crates/tui/src/core/engine.rs b/crates/tui/src/core/engine.rs index fa214617..83cd6e93 100644 --- a/crates/tui/src/core/engine.rs +++ b/crates/tui/src/core/engine.rs @@ -505,6 +505,13 @@ pub struct Engine { slop_ledger_gate_cache: Option<(Option, Option)>, /// Current operating mode. Updated on `ChangeMode` and `SendMessage`. current_mode: AppMode, + /// Process-local cache for `estimated_input_tokens`. Memoizes the most + /// recent token estimate keyed on `(session.messages_revision, + /// system_prompt_fingerprint)`. Five call sites per turn consult this + /// (engine capacity checkpoints, seam manager, trim budget, etc.) plus + /// four TUI / command consumers; the cache turns N×O(messages) walks + /// into a single recompute on a content change. + token_estimate_cache: TokenEstimateCache, } // === Internal tool helpers === @@ -754,6 +761,7 @@ impl Engine { workshop_vars, sandbox_backend, current_mode: AppMode::Agent, + token_estimate_cache: TokenEstimateCache::new(), }; engine.rehydrate_latest_canonical_state(); @@ -1282,6 +1290,7 @@ impl Engine { } if let Some(idx) = cut { self.session.messages.truncate(idx); + self.session.bump_messages_revision(); } // Now dispatch the new message as a normal send, // reusing the engine's stored mode/model config. @@ -2011,10 +2020,15 @@ In {new} mode: {policy}\n\n\ .await; } - fn estimated_input_tokens(&self) -> usize { - estimate_input_tokens_conservative( - &self.session.messages, + fn estimated_input_tokens(&mut self) -> usize { + // Memoized on (session.messages_revision, system-prompt fingerprint). 
+ // The cache invalidates as soon as either input changes; until then + // repeated calls (capacity checkpoints, /status, context inspector, + // TUI footer) all hit the cached value. + self.token_estimate_cache.lookup_or_compute( + self.session.messages_revision, self.session.system_prompt.as_ref(), + &self.session.messages, ) } @@ -2024,6 +2038,7 @@ In {new} mode: {policy}\n\n\ && self.estimated_input_tokens() > target_input_budget { self.session.messages.remove(0); + self.session.bump_messages_revision(); removed = removed.saturating_add(1); } removed @@ -2247,15 +2262,20 @@ In {new} mode: {policy}\n\n\ /// assistant message. Called from `handle_deepseek_turn` before each API /// request so the model always has the latest navigation aids. async fn layered_context_checkpoint(&mut self) { - let Some(ref seam_mgr) = self.seam_manager else { + if self.seam_manager.is_none() { return; - }; - if !seam_mgr.config().enabled { + } + if !self.seam_manager.as_ref().unwrap().config().enabled { return; } + // Compute the estimated token count *before* taking a long-lived + // `&SeamManager` borrow — `estimated_input_tokens` mutates the + // engine's token-estimate cache, which would conflict. 
+ let estimated_tokens = self.estimated_input_tokens(); + let seam_mgr = self.seam_manager.as_ref().unwrap(); let highest = seam_mgr.highest_level().await; - let Some(level) = seam_mgr.seam_level_for(self.estimated_input_tokens(), highest) else { + let Some(level) = seam_mgr.seam_level_for(estimated_tokens, highest) else { return; }; @@ -2636,17 +2656,19 @@ mod handle; pub(crate) use context::compact_tool_result_for_context; use context::{ COMPACTION_SUMMARY_MARKER, MAX_CONTEXT_RECOVERY_ATTEMPTS, MIN_RECENT_MESSAGES_TO_KEEP, - context_input_budget, effective_max_output_tokens, estimate_input_tokens_conservative, - extract_compaction_summary_prompt, is_context_length_error_message, summarize_text, + context_input_budget, effective_max_output_tokens, extract_compaction_summary_prompt, + is_context_length_error_message, summarize_text, }; mod dispatch; mod loop_guard; mod lsp_hooks; mod streaming; +mod token_estimate_cache; mod tool_catalog; mod tool_execution; mod tool_setup; mod turn_loop; +pub(crate) use token_estimate_cache::TokenEstimateCache; pub(crate) fn default_active_native_tool_names() -> &'static [&'static str] { tool_catalog::DEFAULT_ACTIVE_NATIVE_TOOLS diff --git a/crates/tui/src/core/engine/capacity_flow.rs b/crates/tui/src/core/engine/capacity_flow.rs index 06e37f49..ecb2300b 100644 --- a/crates/tui/src/core/engine/capacity_flow.rs +++ b/crates/tui/src/core/engine/capacity_flow.rs @@ -16,9 +16,8 @@ impl Engine { client: Option<&DeepSeekClient>, mode: AppMode, ) -> bool { - let snapshot = self - .capacity_controller - .observe_pre_turn(self.capacity_observation(turn)); + let observation = self.capacity_observation(turn); + let snapshot = self.capacity_controller.observe_pre_turn(observation); let decision = self .capacity_controller .decide(self.turn_counter, snapshot.as_ref()); @@ -44,9 +43,8 @@ impl Engine { _step_error_count: usize, _consecutive_tool_error_steps: u32, ) -> bool { - let snapshot = self - .capacity_controller - 
.observe_post_tool(self.capacity_observation(turn)); + let observation = self.capacity_observation(turn); + let snapshot = self.capacity_controller.observe_post_tool(observation); let decision = self .capacity_controller .decide(self.turn_counter, snapshot.as_ref()); @@ -111,8 +109,8 @@ impl Engine { .last_snapshot() .cloned() .or_else(|| { - self.capacity_controller - .observe_pre_turn(self.capacity_observation(turn)) + let observation = self.capacity_observation(turn); + self.capacity_controller.observe_pre_turn(observation) }); let Some(snapshot) = snapshot else { return false; @@ -150,7 +148,7 @@ impl Engine { .await } - pub(super) fn capacity_observation(&self, turn: &TurnContext) -> CapacityObservationInput { + pub(super) fn capacity_observation(&mut self, turn: &TurnContext) -> CapacityObservationInput { let message_window = self.config.capacity.profile_window.max(8) * 3; let action_count_this_turn = usize::try_from(turn.step) .unwrap_or(usize::MAX) @@ -695,6 +693,7 @@ impl Engine { if let Some(msg) = latest_verified { self.session.messages.push(msg); } + self.session.bump_messages_revision(); self.merge_compaction_summary(Some(self.canonical_prompt( &canonical, diff --git a/crates/tui/src/core/engine/context.rs b/crates/tui/src/core/engine/context.rs index 08ce9004..86e97f0d 100644 --- a/crates/tui/src/core/engine/context.rs +++ b/crates/tui/src/core/engine/context.rs @@ -525,10 +525,12 @@ pub(super) fn extract_compaction_summary_prompt( } } +#[allow(dead_code)] // exposed for future engine-side callers; current call path goes through compaction::estimate_input_tokens_conservative via token_estimate_cache. 
fn estimate_text_tokens_conservative(text: &str) -> usize { text.chars().count().div_ceil(3) } +#[allow(dead_code)] // see estimate_text_tokens_conservative above fn estimate_system_tokens_conservative(system: Option<&SystemPrompt>) -> usize { match system { Some(SystemPrompt::Text(text)) => estimate_text_tokens_conservative(text), @@ -540,6 +542,7 @@ fn estimate_system_tokens_conservative(system: Option<&SystemPrompt>) -> usize { } } +#[allow(dead_code)] // see estimate_text_tokens_conservative above pub(super) fn estimate_input_tokens_conservative( messages: &[Message], system: Option<&SystemPrompt>, diff --git a/crates/tui/src/core/engine/token_estimate_cache.rs b/crates/tui/src/core/engine/token_estimate_cache.rs new file mode 100644 index 00000000..94d191ad --- /dev/null +++ b/crates/tui/src/core/engine/token_estimate_cache.rs @@ -0,0 +1,312 @@ +//! Process-local memoization for [`crate::compaction::estimate_input_tokens_conservative`]. +//! +//! The token estimator walks the full [`crate::models::Message`] history and the +//! active system prompt, which is by far the most expensive per-turn CPU cost +//! in the engine hot path. The same input data is queried from at least five +//! sites per turn: capacity pre/post tool checkpoints, error escalation, +//! the seam manager, and the trimmed-message budget check, plus four more +//! from the TUI footer, `/status`, `/debug`, and the context inspector. +//! +//! Without memoization, a 200-message history with 5 KB of tool results costs +//! ~2 ms per call; that is 20 ms of pure waste on a single turn. The estimator +//! itself is a pure function of `(messages, system_prompt)`, so a +//! content-versioned cache is safe: the caller bumps `messages_revision` +//! on every mutation, and we also include a fast fingerprint of the system +//! prompt as part of the key. +//! +//! The cache is process-local only — cross-session persistence is intentionally +//! 
out of scope (see PR #2520 for the cross-session prompt-base disk cache). + +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +use crate::compaction::estimate_input_tokens_conservative; +use crate::models::{Message, SystemPrompt}; + +/// Default capacity for the rolling audit ring. Sized so a 64-entry window +/// covers a full capacity controller observation cycle without unbounded +/// growth on long-running sessions. +const AUDIT_RING_CAPACITY: usize = 64; + +/// Process-local memoization for `estimate_input_tokens_conservative`. +/// +/// The cache is keyed on the `(messages_revision, system_fingerprint)` +/// pair; a change in either one forces a recompute on the next lookup. On a hit +/// the previously stored token estimate is returned without re-walking the +/// message list. On a miss, the estimator runs and the result is stored +/// alongside the audit ring entry. +#[derive(Debug, Default, Clone)] +pub struct TokenEstimateCache { + /// Session message revision last seen by the cache (set on a miss or a bump). + messages_revision: u64, + /// 64-bit hash of the system prompt the cached estimate was computed + /// for. The incoming prompt is rehashed on every `lookup_or_compute` call. + system_fingerprint: u64, + /// Cached token count, valid iff both keys match the current inputs. + cached_tokens: Option<usize>, + /// Audit ring of recent (revision, tokens) pairs. The most recent entry + /// is the tail; the oldest is dropped when capacity is exceeded. Used by + /// observability to surface cache effectiveness to `/status`. + audit_ring: Vec<(u64, usize)>, + /// Number of cache hits since the cache was last cleared. Saturates at + /// `u64::MAX` (effectively never in practice). + hits: u64, + /// Number of cache misses since the cache was last cleared. + misses: u64, +} + +impl TokenEstimateCache { + /// Construct a fresh, empty cache. 
`messages_revision` defaults to 0; the + /// engine passes the session's current revision to [`lookup_or_compute`](Self::lookup_or_compute), + /// so any change in revision makes the next lookup miss and recompute. + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Returns the cached token estimate, recomputing on miss. + /// + /// `messages_revision` is the engine's monotonic counter; bump it on + /// every add/remove/clear. `system_prompt` may be `None`. `messages` is + /// borrowed for the duration of the call so a miss can re-tokenize. + pub fn lookup_or_compute( + &mut self, + messages_revision: u64, + system_prompt: Option<&SystemPrompt>, + messages: &[Message], + ) -> usize { + let system_fingerprint = fingerprint_system_prompt(system_prompt); + + if self.messages_revision == messages_revision + && self.system_fingerprint == system_fingerprint + && let Some(tokens) = self.cached_tokens + { + self.hits = self.hits.saturating_add(1); + return tokens; + } + + let tokens = estimate_input_tokens_conservative(messages, system_prompt); + self.messages_revision = messages_revision; + self.system_fingerprint = system_fingerprint; + self.cached_tokens = Some(tokens); + self.misses = self.misses.saturating_add(1); + self.push_audit(messages_revision, tokens); + tokens + } + + /// Record a messages-revision bump. Not yet wired into the engine, which + /// passes the session revision to `lookup_or_compute` instead. A value that + /// is not greater than the current revision is a no-op. + #[allow(dead_code)] // exposed for future wiring of /clear and reset paths; tests exercise it + pub fn bump_messages_revision(&mut self, revision: u64) { + if revision > self.messages_revision { + self.messages_revision = revision; + self.cached_tokens = None; + } + } + + /// Drop the cached estimate, audit ring, and counters. Intended for `/clear` and reset paths. 
+ #[allow(dead_code)] // exposed for future wiring of /clear and reset paths; tests exercise it + pub fn invalidate(&mut self) { + self.cached_tokens = None; + self.system_fingerprint = 0; + self.audit_ring.clear(); + self.hits = 0; + self.misses = 0; + } + + /// Returns `(hits, misses)` counters since the last `invalidate` call. + #[allow(dead_code)] // surfaced via /status in a follow-up; tests exercise it + #[must_use] + pub fn stats(&self) -> (u64, u64) { + (self.hits, self.misses) + } + + /// Returns the most recent `(revision, tokens)` audit entries, oldest + /// first (the newest entry is last). Bounded by [`AUDIT_RING_CAPACITY`]. + #[allow(dead_code)] // surfaced via /status in a follow-up; tests exercise it + #[must_use] + pub fn recent_audit(&self) -> &[(u64, usize)] { + &self.audit_ring + } + + fn push_audit(&mut self, revision: u64, tokens: usize) { + if self.audit_ring.len() >= AUDIT_RING_CAPACITY { + self.audit_ring.remove(0); + } + self.audit_ring.push((revision, tokens)); + } +} + +/// 64-bit hash of the system prompt text, consistent within one process. Walks the same shape the +/// estimator consumes: a `Text` variant or a list of `Blocks`. Returns 0 +/// for `None` so the empty case is distinguishable but cheap to compare. 
+fn fingerprint_system_prompt(system: Option<&SystemPrompt>) -> u64 { + let Some(system) = system else { + return 0; + }; + let mut hasher = DefaultHasher::new(); + match system { + SystemPrompt::Text(text) => { + "text".hash(&mut hasher); + text.hash(&mut hasher); + } + SystemPrompt::Blocks(blocks) => { + "blocks".hash(&mut hasher); + blocks.len().hash(&mut hasher); + for block in blocks { + block.block_type.hash(&mut hasher); + block.text.hash(&mut hasher); + } + } + } + hasher.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::{ContentBlock, SystemBlock}; + + fn user_text(s: &str) -> Message { + Message { + role: "user".to_string(), + content: vec![ContentBlock::Text { + text: s.to_string(), + cache_control: None, + }], + } + } + + fn sys_text(s: &str) -> SystemPrompt { + SystemPrompt::Text(s.to_string()) + } + + #[test] + fn first_call_is_a_miss() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("hello world")]; + let tokens = cache.lookup_or_compute(1, None, &messages); + let (hits, misses) = cache.stats(); + assert!(tokens > 0); + assert_eq!(hits, 0); + assert_eq!(misses, 1); + } + + #[test] + fn repeated_call_with_same_revision_is_a_hit() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("hello world")]; + let _ = cache.lookup_or_compute(1, None, &messages); + let _ = cache.lookup_or_compute(1, None, &messages); + let (hits, misses) = cache.stats(); + assert_eq!(hits, 1); + assert_eq!(misses, 1); + } + + #[test] + fn revision_bump_invalidates() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("hi")]; + let a = cache.lookup_or_compute(1, None, &messages); + let b = cache.lookup_or_compute(2, None, &messages); + let (hits, misses) = cache.stats(); + // Both calls were misses (different revisions), neither hit the cache. 
+ assert_eq!(a, b); + assert_eq!(hits, 0); + assert_eq!(misses, 2); + } + + #[test] + fn system_prompt_change_invalidates() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("hi")]; + let _ = cache.lookup_or_compute(1, Some(&sys_text("alpha")), &messages); + let _ = cache.lookup_or_compute(1, Some(&sys_text("beta")), &messages); + let (hits, misses) = cache.stats(); + assert_eq!(hits, 0); + assert_eq!(misses, 2); + } + + #[test] + fn bump_messages_revision_clears_cache() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("x")]; + let _ = cache.lookup_or_compute(1, None, &messages); + cache.bump_messages_revision(2); + let _ = cache.lookup_or_compute(2, None, &messages); + let (hits, misses) = cache.stats(); + assert_eq!(hits, 0); + assert_eq!(misses, 2); + } + + #[test] + fn bump_to_smaller_revision_is_noop() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("x")]; + let _ = cache.lookup_or_compute(5, None, &messages); + cache.bump_messages_revision(2); + // revision went down, cache should still be valid for revision 5 + let _ = cache.lookup_or_compute(5, None, &messages); + let (hits, _) = cache.stats(); + assert_eq!(hits, 1, "downward revision bumps must not invalidate"); + } + + #[test] + fn invalidate_resets_state() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("x")]; + let _ = cache.lookup_or_compute(1, None, &messages); + let _ = cache.lookup_or_compute(1, None, &messages); + cache.invalidate(); + let (hits, misses) = cache.stats(); + assert_eq!(hits, 0); + assert_eq!(misses, 0); + } + + #[test] + fn blocks_system_prompt_yields_distinct_fingerprint() { + let blocks_a = SystemPrompt::Blocks(vec![SystemBlock { + block_type: "text".to_string(), + text: "alpha".to_string(), + cache_control: None, + }]); + let blocks_b = SystemPrompt::Blocks(vec![SystemBlock { + block_type: "text".to_string(), + text: "beta".to_string(), + cache_control: None, + 
}]); + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("hi")]; + let _ = cache.lookup_or_compute(1, Some(&blocks_a), &messages); + let _ = cache.lookup_or_compute(1, Some(&blocks_b), &messages); + let (hits, misses) = cache.stats(); + assert_eq!(hits, 0); + assert_eq!(misses, 2); + } + + #[test] + fn audit_ring_records_recent_pairs() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("hi")]; + for rev in 1..=5 { + let _ = cache.lookup_or_compute(rev, None, &messages); + } + let ring = cache.recent_audit(); + assert_eq!(ring.len(), 5); + assert_eq!(ring.last().copied(), Some((5, ring.last().unwrap().1))); + } + + #[test] + fn audit_ring_bounded_by_capacity() { + let mut cache = TokenEstimateCache::new(); + let messages = vec![user_text("hi")]; + for rev in 1..=(AUDIT_RING_CAPACITY + 10) as u64 { + let _ = cache.lookup_or_compute(rev, None, &messages); + } + let ring = cache.recent_audit(); + assert_eq!(ring.len(), AUDIT_RING_CAPACITY); + // newest entry should be the most recent revision we asked for + assert_eq!(ring.last().unwrap().0, (AUDIT_RING_CAPACITY + 10) as u64); + } +} diff --git a/crates/tui/src/core/session.rs b/crates/tui/src/core/session.rs index 49943c71..dccdd913 100644 --- a/crates/tui/src/core/session.rs +++ b/crates/tui/src/core/session.rs @@ -82,6 +82,14 @@ pub struct Session { /// request of the session; verified against the current system+tool /// state before every subsequent request. None until the first turn. pub frozen_prefix: Option, + + /// Monotonic counter bumped on every direct mutation of `messages`. + /// Consumed by [`crate::core::engine::token_estimate_cache::TokenEstimateCache`] + /// to memoize the per-turn token estimate without re-walking the message + /// list. Defaults to 0; bumped in [`Session::add_message`], + /// [`Session::replace_messages`], and at every other mutation site in + /// `core/engine.rs` / `core/engine/capacity_flow.rs`. 
+ pub messages_revision: u64, } /// Cumulative usage statistics for a session. @@ -155,12 +163,33 @@ impl Session { working_set: WorkingSet::default(), prefix_stability: None, frozen_prefix: None, + messages_revision: 0, } } /// Add a message to the conversation pub fn add_message(&mut self, message: Message) { self.messages.push(message); + self.messages_revision = self.messages_revision.saturating_add(1); + } + + /// Replace the entire message history. Used by session resume and + /// capacity interventions. Bumps `messages_revision` exactly once even + /// when the new history has a different length, so downstream caches + /// invalidate atomically. + #[allow(dead_code)] + pub fn replace_messages(&mut self, messages: Vec) { + self.messages = messages; + self.messages_revision = self.messages_revision.saturating_add(1); + } + + /// Bump `messages_revision` without otherwise mutating the message list. + /// Reserved for sites that mutate the message list in place (e.g. an + /// in-place rewrite of a content block). Most call sites do not need + /// this — prefer [`add_message`](Self::add_message) and + /// [`replace_messages`](Self::replace_messages). + pub fn bump_messages_revision(&mut self) { + self.messages_revision = self.messages_revision.saturating_add(1); } /// Rebuild the working set from current messages (best effort).