feat(runtime): #133 add fork_at_user_message for backtrack rewind
Adds `RuntimeThreadManager::fork_at_user_message(id, depth_from_tail)` — a sibling of the existing `fork_thread` that drops every turn from the Nth-from-tail user message onward and returns the dropped user input so the caller can pre-populate the composer. The existing `fork_thread` is left untouched. The new helper mirrors its copy loop but stops short of the cutoff turn, emitting a `thread.forked` event with backtrack provenance fields. Includes unit tests covering depth=0, depth=1, out-of-range error, and source-thread non-mutation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -843,6 +843,125 @@ impl RuntimeThreadManager {
|
||||
Ok(forked)
|
||||
}
|
||||
|
||||
/// Fork a thread, dropping every turn from the Nth-from-tail user
|
||||
/// message onward (issue #133 — Esc-Esc backtrack).
|
||||
///
|
||||
/// `depth_from_tail` selects which user turn to roll back *to*:
|
||||
///
|
||||
/// - `0` — drop the most recent turn (the freshest user message and
|
||||
/// everything after it)
|
||||
/// - `1` — drop the two most recent turns (rewind one further)
|
||||
/// - …and so on
|
||||
///
|
||||
/// Returns a tuple of `(forked_thread, original_user_text)` where the
|
||||
/// second element is the `detail` of the first `UserMessage` item in
|
||||
/// the *first dropped* turn — i.e. the input the user typed to start
|
||||
/// that turn — so the caller can pre-populate the composer with it.
|
||||
/// `None` when no detail was recorded (defensive — every persisted
|
||||
/// `UserMessage` since v0.6 carries a detail string).
|
||||
///
|
||||
/// Counts user turns by iterating `list_turns_for_thread` (sorted
|
||||
/// oldest → newest) backwards. A turn is counted as a "user turn"
|
||||
/// when at least one of its items has `kind ==
|
||||
/// TurnItemKind::UserMessage`. Steered turns (which append additional
|
||||
/// `UserMessage` items) still count as one turn — backtrack rewinds
|
||||
/// at the turn boundary, not at the steer boundary.
|
||||
///
|
||||
/// Errors:
|
||||
/// - `depth_from_tail` exceeds the number of user turns
|
||||
/// - source thread not found
|
||||
#[allow(dead_code)] // exposed for the runtime/HTTP fork-on-backtrack path; the in-TUI Esc-Esc flow trims `App` state directly. Issue #133.
|
||||
pub async fn fork_at_user_message(
|
||||
&self,
|
||||
id: &str,
|
||||
depth_from_tail: usize,
|
||||
) -> Result<(ThreadRecord, Option<String>)> {
|
||||
let source = self.get_thread(id).await?;
|
||||
let source_turns = self.store.list_turns_for_thread(&source.id)?;
|
||||
|
||||
// Walk turns from newest to oldest. For each turn, ask: does it
|
||||
// contain a UserMessage item? If yes, it counts toward the depth.
|
||||
let mut user_turn_indices: Vec<usize> = Vec::new();
|
||||
for (idx, turn) in source_turns.iter().enumerate().rev() {
|
||||
let items = self.store.list_items_for_turn(&turn.id)?;
|
||||
if items
|
||||
.iter()
|
||||
.any(|item| item.kind == TurnItemKind::UserMessage)
|
||||
{
|
||||
user_turn_indices.push(idx);
|
||||
}
|
||||
}
|
||||
if depth_from_tail >= user_turn_indices.len() {
|
||||
bail!(
|
||||
"fork_at_user_message: depth {} exceeds {} user turn(s)",
|
||||
depth_from_tail,
|
||||
user_turn_indices.len()
|
||||
);
|
||||
}
|
||||
// `user_turn_indices` is newest-first because we iterated in
|
||||
// reverse, so the Nth element is exactly the Nth-from-tail user
|
||||
// turn in the original chronological list.
|
||||
let target_turn_idx = user_turn_indices[depth_from_tail];
|
||||
let target_turn_id = source_turns[target_turn_idx].id.clone();
|
||||
|
||||
// Pull the original user-message text out of the dropped turn so
|
||||
// the caller can drop it back into the composer.
|
||||
let target_items = self.store.list_items_for_turn(&target_turn_id)?;
|
||||
let original_user_text = target_items
|
||||
.iter()
|
||||
.find(|item| item.kind == TurnItemKind::UserMessage)
|
||||
.and_then(|item| item.detail.clone());
|
||||
|
||||
// Copy turns strictly before `target_turn_idx` into a new thread.
|
||||
// Mirrors `fork_thread` but stops at the cutoff instead of copying
|
||||
// every turn. Kept structurally close so future parity reviews
|
||||
// can spot drift between the two paths.
|
||||
let mut forked = source.clone();
|
||||
let now = Utc::now();
|
||||
forked.id = format!("thr_{}", &Uuid::new_v4().to_string()[..8]);
|
||||
forked.created_at = now;
|
||||
forked.updated_at = now;
|
||||
forked.latest_turn_id = None;
|
||||
forked.archived = false;
|
||||
self.store.save_thread(&forked)?;
|
||||
|
||||
for source_turn in source_turns.iter().take(target_turn_idx) {
|
||||
let mut cloned_turn = source_turn.clone();
|
||||
cloned_turn.id = format!("turn_{}", &Uuid::new_v4().to_string()[..8]);
|
||||
cloned_turn.thread_id = forked.id.clone();
|
||||
cloned_turn.item_ids.clear();
|
||||
self.store.save_turn(&cloned_turn)?;
|
||||
|
||||
let items = self.store.list_items_for_turn(&source_turn.id)?;
|
||||
for item in items {
|
||||
let mut cloned_item = item.clone();
|
||||
cloned_item.id = format!("item_{}", &Uuid::new_v4().to_string()[..8]);
|
||||
cloned_item.turn_id = cloned_turn.id.clone();
|
||||
self.store.save_item(&cloned_item)?;
|
||||
cloned_turn.item_ids.push(cloned_item.id.clone());
|
||||
}
|
||||
self.store.save_turn(&cloned_turn)?;
|
||||
forked.latest_turn_id = Some(cloned_turn.id.clone());
|
||||
forked.updated_at = now;
|
||||
self.store.save_thread(&forked)?;
|
||||
}
|
||||
|
||||
self.emit_event(
|
||||
&forked.id,
|
||||
None,
|
||||
None,
|
||||
"thread.forked",
|
||||
json!({
|
||||
"thread": forked,
|
||||
"source_thread_id": source.id,
|
||||
"backtrack_depth_from_tail": depth_from_tail,
|
||||
"dropped_turn_id": target_turn_id,
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
Ok((forked, original_user_text))
|
||||
}
|
||||
|
||||
/// Seed a thread with messages from a saved session so subsequent turns
|
||||
/// continue with the prior conversation context.
|
||||
pub async fn seed_thread_from_messages(
|
||||
@@ -3978,4 +4097,191 @@ mod tests {
|
||||
assert_eq!(hints.len(), 1);
|
||||
assert_eq!(hints[0].status, AgentRebindStatus::Completed);
|
||||
}
|
||||
|
||||
/// Helper for the `fork_at_user_message` tests: write a sequence of
|
||||
/// (user, assistant) turns under the given thread id. Each turn gets
|
||||
/// one UserMessage item carrying `user_text` in `detail` plus one
|
||||
/// AgentMessage item. Turn `created_at` is monotonically increasing
|
||||
/// so the chronological sort in `list_turns_for_thread` is stable.
|
||||
fn seed_turns_with_user_messages(
|
||||
manager: &RuntimeThreadManager,
|
||||
thread_id: &str,
|
||||
user_texts: &[&str],
|
||||
) -> Result<Vec<String>> {
|
||||
let mut turn_ids = Vec::new();
|
||||
let base = Utc::now();
|
||||
for (offset, text) in user_texts.iter().enumerate() {
|
||||
let created_at = base + chrono::Duration::milliseconds(offset as i64);
|
||||
let turn_id = format!("turn_test_{offset}");
|
||||
let user_item_id = format!("item_user_{offset}");
|
||||
let asst_item_id = format!("item_asst_{offset}");
|
||||
manager.store.save_item(&TurnItemRecord {
|
||||
schema_version: CURRENT_RUNTIME_SCHEMA_VERSION,
|
||||
id: user_item_id.clone(),
|
||||
turn_id: turn_id.clone(),
|
||||
kind: TurnItemKind::UserMessage,
|
||||
status: TurnItemLifecycleStatus::Completed,
|
||||
summary: (*text).to_string(),
|
||||
detail: Some((*text).to_string()),
|
||||
artifact_refs: Vec::new(),
|
||||
started_at: Some(created_at),
|
||||
ended_at: Some(created_at),
|
||||
})?;
|
||||
manager.store.save_item(&TurnItemRecord {
|
||||
schema_version: CURRENT_RUNTIME_SCHEMA_VERSION,
|
||||
id: asst_item_id.clone(),
|
||||
turn_id: turn_id.clone(),
|
||||
kind: TurnItemKind::AgentMessage,
|
||||
status: TurnItemLifecycleStatus::Completed,
|
||||
summary: format!("reply {offset}"),
|
||||
detail: Some(format!("reply {offset}")),
|
||||
artifact_refs: Vec::new(),
|
||||
started_at: Some(created_at),
|
||||
ended_at: Some(created_at),
|
||||
})?;
|
||||
manager.store.save_turn(&TurnRecord {
|
||||
schema_version: CURRENT_RUNTIME_SCHEMA_VERSION,
|
||||
id: turn_id.clone(),
|
||||
thread_id: thread_id.to_string(),
|
||||
status: RuntimeTurnStatus::Completed,
|
||||
input_summary: (*text).to_string(),
|
||||
created_at,
|
||||
started_at: Some(created_at),
|
||||
ended_at: Some(created_at),
|
||||
duration_ms: Some(0),
|
||||
usage: None,
|
||||
error: None,
|
||||
item_ids: vec![user_item_id, asst_item_id],
|
||||
steer_count: 0,
|
||||
})?;
|
||||
turn_ids.push(turn_id);
|
||||
}
|
||||
Ok(turn_ids)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fork_at_user_message_drops_tail_and_returns_user_text() -> Result<()> {
|
||||
// Seed three completed user/assistant turns. Backtracking with
|
||||
// depth=0 should drop only the most recent turn ("third") and
|
||||
// hand back its original text so the caller can refill the
|
||||
// composer.
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
})
|
||||
.await?;
|
||||
seed_turns_with_user_messages(&manager, &thread.id, &["first", "second", "third"])?;
|
||||
|
||||
let (forked, original_text) = manager.fork_at_user_message(&thread.id, 0).await?;
|
||||
assert_eq!(original_text.as_deref(), Some("third"));
|
||||
assert_ne!(forked.id, thread.id);
|
||||
|
||||
let forked_turns = manager.store.list_turns_for_thread(&forked.id)?;
|
||||
assert_eq!(
|
||||
forked_turns.len(),
|
||||
2,
|
||||
"depth=0 should drop the most recent turn"
|
||||
);
|
||||
let summaries: Vec<&str> = forked_turns
|
||||
.iter()
|
||||
.map(|t| t.input_summary.as_str())
|
||||
.collect();
|
||||
assert_eq!(summaries, vec!["first", "second"]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fork_at_user_message_depth_one_drops_two_turns() -> Result<()> {
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
})
|
||||
.await?;
|
||||
seed_turns_with_user_messages(&manager, &thread.id, &["a", "b", "c", "d"])?;
|
||||
|
||||
let (forked, original_text) = manager.fork_at_user_message(&thread.id, 1).await?;
|
||||
assert_eq!(original_text.as_deref(), Some("c"));
|
||||
let forked_turns = manager.store.list_turns_for_thread(&forked.id)?;
|
||||
let summaries: Vec<&str> = forked_turns
|
||||
.iter()
|
||||
.map(|t| t.input_summary.as_str())
|
||||
.collect();
|
||||
assert_eq!(summaries, vec!["a", "b"]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fork_at_user_message_out_of_range_errors() -> Result<()> {
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
})
|
||||
.await?;
|
||||
seed_turns_with_user_messages(&manager, &thread.id, &["only"])?;
|
||||
|
||||
let err = manager.fork_at_user_message(&thread.id, 5).await.err();
|
||||
assert!(err.is_some(), "depth past the end should bail out");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fork_at_user_message_does_not_mutate_source() -> Result<()> {
|
||||
// The source thread must be untouched: turns still present, items
|
||||
// still present, latest_turn_id still pointing at the original
|
||||
// tail. Backtrack creates a sibling, never edits in place.
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
})
|
||||
.await?;
|
||||
let turn_ids = seed_turns_with_user_messages(&manager, &thread.id, &["x", "y", "z"])?;
|
||||
|
||||
let _ = manager.fork_at_user_message(&thread.id, 0).await?;
|
||||
|
||||
let source_turns = manager.store.list_turns_for_thread(&thread.id)?;
|
||||
assert_eq!(
|
||||
source_turns.len(),
|
||||
3,
|
||||
"source thread must still hold every turn after fork"
|
||||
);
|
||||
for tid in &turn_ids {
|
||||
assert!(
|
||||
manager.store.load_turn(tid).is_ok(),
|
||||
"turn {tid} must remain on disk"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user