From 26b79312f9db7a958bef08cd0d286bc38c347fce Mon Sep 17 00:00:00 2001 From: Hunter Bown Date: Tue, 28 Apr 2026 00:26:00 -0500 Subject: [PATCH] feat(runtime): #133 add fork_at_user_message for backtrack rewind MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `RuntimeThreadManager::fork_at_user_message(id, depth_from_tail)` — a sibling of the existing `fork_thread` that drops every turn from the Nth-from-tail user message onward and returns the dropped user input so the caller can pre-populate the composer. The existing `fork_thread` is left untouched. The new helper mirrors its copy loop but stops short of the cutoff turn, emitting a `thread.forked` event with backtrack provenance fields. Includes unit tests covering depth=0, depth=1, out-of-range error, and source-thread non-mutation. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/tui/src/runtime_threads.rs | 306 ++++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) diff --git a/crates/tui/src/runtime_threads.rs b/crates/tui/src/runtime_threads.rs index ed942108..9680c751 100644 --- a/crates/tui/src/runtime_threads.rs +++ b/crates/tui/src/runtime_threads.rs @@ -843,6 +843,125 @@ impl RuntimeThreadManager { Ok(forked) } + /// Fork a thread, dropping every turn from the Nth-from-tail user + /// message onward (issue #133 — Esc-Esc backtrack). + /// + /// `depth_from_tail` selects which user turn to roll back *to*: + /// + /// - `0` — drop the most recent turn (the freshest user message and + /// everything after it) + /// - `1` — drop the two most recent turns (rewind one further) + /// - …and so on + /// + /// Returns a tuple of `(forked_thread, original_user_text)` where the + /// second element is the `detail` of the first `UserMessage` item in + /// the *first dropped* turn — i.e. the input the user typed to start + /// that turn — so the caller can pre-populate the composer with it. + /// `None` when no detail was recorded (defensive — every persisted + /// `UserMessage` since v0.6 carries a detail string). + /// + /// Counts user turns by iterating `list_turns_for_thread` (sorted + /// oldest → newest) backwards. A turn is counted as a "user turn" + /// when at least one of its items has `kind == + /// TurnItemKind::UserMessage`. Steered turns (which append additional + /// `UserMessage` items) still count as one turn — backtrack rewinds + /// at the turn boundary, not at the steer boundary. + /// + /// Errors: + /// - `depth_from_tail` exceeds the number of user turns + /// - source thread not found + #[allow(dead_code)] // exposed for the runtime/HTTP fork-on-backtrack path; the in-TUI Esc-Esc flow trims `App` state directly. Issue #133. + pub async fn fork_at_user_message( + &self, + id: &str, + depth_from_tail: usize, + ) -> Result<(ThreadRecord, Option)> { + let source = self.get_thread(id).await?; + let source_turns = self.store.list_turns_for_thread(&source.id)?; + + // Walk turns from newest to oldest. For each turn, ask: does it + // contain a UserMessage item? If yes, it counts toward the depth. + let mut user_turn_indices: Vec = Vec::new(); + for (idx, turn) in source_turns.iter().enumerate().rev() { + let items = self.store.list_items_for_turn(&turn.id)?; + if items + .iter() + .any(|item| item.kind == TurnItemKind::UserMessage) + { + user_turn_indices.push(idx); + } + } + if depth_from_tail >= user_turn_indices.len() { + bail!( + "fork_at_user_message: depth {} exceeds {} user turn(s)", + depth_from_tail, + user_turn_indices.len() + ); + } + // `user_turn_indices` is newest-first because we iterated in + // reverse, so the Nth element is exactly the Nth-from-tail user + // turn in the original chronological list. + let target_turn_idx = user_turn_indices[depth_from_tail]; + let target_turn_id = source_turns[target_turn_idx].id.clone(); + + // Pull the original user-message text out of the dropped turn so + // the caller can drop it back into the composer. + let target_items = self.store.list_items_for_turn(&target_turn_id)?; + let original_user_text = target_items + .iter() + .find(|item| item.kind == TurnItemKind::UserMessage) + .and_then(|item| item.detail.clone()); + + // Copy turns strictly before `target_turn_idx` into a new thread. + // Mirrors `fork_thread` but stops at the cutoff instead of copying + // every turn. Kept structurally close so future parity reviews + // can spot drift between the two paths. + let mut forked = source.clone(); + let now = Utc::now(); + forked.id = format!("thr_{}", &Uuid::new_v4().to_string()[..8]); + forked.created_at = now; + forked.updated_at = now; + forked.latest_turn_id = None; + forked.archived = false; + self.store.save_thread(&forked)?; + + for source_turn in source_turns.iter().take(target_turn_idx) { + let mut cloned_turn = source_turn.clone(); + cloned_turn.id = format!("turn_{}", &Uuid::new_v4().to_string()[..8]); + cloned_turn.thread_id = forked.id.clone(); + cloned_turn.item_ids.clear(); + self.store.save_turn(&cloned_turn)?; + + let items = self.store.list_items_for_turn(&source_turn.id)?; + for item in items { + let mut cloned_item = item.clone(); + cloned_item.id = format!("item_{}", &Uuid::new_v4().to_string()[..8]); + cloned_item.turn_id = cloned_turn.id.clone(); + self.store.save_item(&cloned_item)?; + cloned_turn.item_ids.push(cloned_item.id.clone()); + } + self.store.save_turn(&cloned_turn)?; + forked.latest_turn_id = Some(cloned_turn.id.clone()); + forked.updated_at = now; + self.store.save_thread(&forked)?; + } + + self.emit_event( + &forked.id, + None, + None, + "thread.forked", + json!({ + "thread": forked, + "source_thread_id": source.id, + "backtrack_depth_from_tail": depth_from_tail, + "dropped_turn_id": target_turn_id, + }), + ) + .await?; + Ok((forked, original_user_text)) + } + /// Seed a thread with messages from a saved session so subsequent turns /// continue with the prior conversation context. pub async fn seed_thread_from_messages( @@ -3978,4 +4097,191 @@ mod tests { assert_eq!(hints.len(), 1); assert_eq!(hints[0].status, AgentRebindStatus::Completed); } + + /// Helper for the `fork_at_user_message` tests: write a sequence of + /// (user, assistant) turns under the given thread id. Each turn gets + /// one UserMessage item carrying `user_text` in `detail` plus one + /// AgentMessage item. Turn `created_at` is monotonically increasing + /// so the chronological sort in `list_turns_for_thread` is stable. + fn seed_turns_with_user_messages( + manager: &RuntimeThreadManager, + thread_id: &str, + user_texts: &[&str], + ) -> Result> { + let mut turn_ids = Vec::new(); + let base = Utc::now(); + for (offset, text) in user_texts.iter().enumerate() { + let created_at = base + chrono::Duration::milliseconds(offset as i64); + let turn_id = format!("turn_test_{offset}"); + let user_item_id = format!("item_user_{offset}"); + let asst_item_id = format!("item_asst_{offset}"); + manager.store.save_item(&TurnItemRecord { + schema_version: CURRENT_RUNTIME_SCHEMA_VERSION, + id: user_item_id.clone(), + turn_id: turn_id.clone(), + kind: TurnItemKind::UserMessage, + status: TurnItemLifecycleStatus::Completed, + summary: (*text).to_string(), + detail: Some((*text).to_string()), + artifact_refs: Vec::new(), + started_at: Some(created_at), + ended_at: Some(created_at), + })?; + manager.store.save_item(&TurnItemRecord { + schema_version: CURRENT_RUNTIME_SCHEMA_VERSION, + id: asst_item_id.clone(), + turn_id: turn_id.clone(), + kind: TurnItemKind::AgentMessage, + status: TurnItemLifecycleStatus::Completed, + summary: format!("reply {offset}"), + detail: Some(format!("reply {offset}")), + artifact_refs: Vec::new(), + started_at: Some(created_at), + ended_at: Some(created_at), + })?; + manager.store.save_turn(&TurnRecord { + schema_version: CURRENT_RUNTIME_SCHEMA_VERSION, + id: turn_id.clone(), + thread_id: thread_id.to_string(), + status: RuntimeTurnStatus::Completed, + input_summary: (*text).to_string(), + created_at, + started_at: Some(created_at), + ended_at: Some(created_at), + duration_ms: Some(0), + usage: None, + error: None, + item_ids: vec![user_item_id, asst_item_id], + steer_count: 0, + })?; + turn_ids.push(turn_id); + } + Ok(turn_ids) + } + + #[tokio::test] + async fn fork_at_user_message_drops_tail_and_returns_user_text() -> Result<()> { + // Seed three completed user/assistant turns. Backtracking with + // depth=0 should drop only the most recent turn ("third") and + // hand back its original text so the caller can refill the + // composer. + let manager = test_manager(test_runtime_dir())?; + let thread = manager + .create_thread(CreateThreadRequest { + model: None, + workspace: None, + mode: None, + allow_shell: None, + trust_mode: None, + auto_approve: None, + archived: false, + system_prompt: None, + }) + .await?; + seed_turns_with_user_messages(&manager, &thread.id, &["first", "second", "third"])?; + + let (forked, original_text) = manager.fork_at_user_message(&thread.id, 0).await?; + assert_eq!(original_text.as_deref(), Some("third")); + assert_ne!(forked.id, thread.id); + + let forked_turns = manager.store.list_turns_for_thread(&forked.id)?; + assert_eq!( + forked_turns.len(), + 2, + "depth=0 should drop the most recent turn" + ); + let summaries: Vec<&str> = forked_turns + .iter() + .map(|t| t.input_summary.as_str()) + .collect(); + assert_eq!(summaries, vec!["first", "second"]); + Ok(()) + } + + #[tokio::test] + async fn fork_at_user_message_depth_one_drops_two_turns() -> Result<()> { + let manager = test_manager(test_runtime_dir())?; + let thread = manager + .create_thread(CreateThreadRequest { + model: None, + workspace: None, + mode: None, + allow_shell: None, + trust_mode: None, + auto_approve: None, + archived: false, + system_prompt: None, + }) + .await?; + seed_turns_with_user_messages(&manager, &thread.id, &["a", "b", "c", "d"])?; + + let (forked, original_text) = manager.fork_at_user_message(&thread.id, 1).await?; + assert_eq!(original_text.as_deref(), Some("c")); + let forked_turns = manager.store.list_turns_for_thread(&forked.id)?; + let summaries: Vec<&str> = forked_turns + .iter() + .map(|t| t.input_summary.as_str()) + .collect(); + assert_eq!(summaries, vec!["a", "b"]); + Ok(()) + } + + #[tokio::test] + async fn fork_at_user_message_out_of_range_errors() -> Result<()> { + let manager = test_manager(test_runtime_dir())?; + let thread = manager + .create_thread(CreateThreadRequest { + model: None, + workspace: None, + mode: None, + allow_shell: None, + trust_mode: None, + auto_approve: None, + archived: false, + system_prompt: None, + }) + .await?; + seed_turns_with_user_messages(&manager, &thread.id, &["only"])?; + + let err = manager.fork_at_user_message(&thread.id, 5).await.err(); + assert!(err.is_some(), "depth past the end should bail out"); + Ok(()) + } + + #[tokio::test] + async fn fork_at_user_message_does_not_mutate_source() -> Result<()> { + // The source thread must be untouched: turns still present, items + // still present, latest_turn_id still pointing at the original + // tail. Backtrack creates a sibling, never edits in place. + let manager = test_manager(test_runtime_dir())?; + let thread = manager + .create_thread(CreateThreadRequest { + model: None, + workspace: None, + mode: None, + allow_shell: None, + trust_mode: None, + auto_approve: None, + archived: false, + system_prompt: None, + }) + .await?; + let turn_ids = seed_turns_with_user_messages(&manager, &thread.id, &["x", "y", "z"])?; + + let _ = manager.fork_at_user_message(&thread.id, 0).await?; + + let source_turns = manager.store.list_turns_for_thread(&thread.id)?; + assert_eq!( + source_turns.len(), + 3, + "source thread must still hold every turn after fork" + ); + for tid in &turn_ids { + assert!( + manager.store.load_turn(tid).is_ok(), + "turn {tid} must remain on disk" + ); + } + Ok(()) + } }