From 7993f97f8804db5c50899b3d95261ff1d13acf96 Mon Sep 17 00:00:00 2001 From: hqt Date: Thu, 28 May 2026 15:54:58 +0800 Subject: [PATCH] feat(state): add parent_entry_id on the message table for fork support --- crates/core/src/lib.rs | 1 + crates/state/src/lib.rs | 354 ++++++++++++++++++++++------- crates/state/tests/parity_state.rs | 195 ++++++++++++++++ 3 files changed, 468 insertions(+), 82 deletions(-) diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index e6d9f094..472095cc 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -643,6 +643,7 @@ impl ThreadManager { git_branch: None, git_origin_url: None, memory_mode: None, + current_leaf_id: None, }) } } diff --git a/crates/state/src/lib.rs b/crates/state/src/lib.rs index 9bad8a16..7347245d 100644 --- a/crates/state/src/lib.rs +++ b/crates/state/src/lib.rs @@ -53,6 +53,7 @@ pub struct ThreadMetadata { pub git_branch: Option, pub git_origin_url: Option, pub memory_mode: Option, + pub current_leaf_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -71,6 +72,7 @@ pub struct MessageRecord { pub content: String, pub item: Option, pub created_at: i64, + pub parent_entry_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -162,79 +164,107 @@ impl StateStore { fn init_schema(&self) -> Result<()> { let conn = self.conn()?; - conn.execute_batch( - r#" - CREATE TABLE IF NOT EXISTS threads ( - id TEXT PRIMARY KEY, - rollout_path TEXT, - preview TEXT NOT NULL, - ephemeral INTEGER NOT NULL, - model_provider TEXT NOT NULL, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL, - status TEXT NOT NULL, - path TEXT, - cwd TEXT NOT NULL, - cli_version TEXT NOT NULL, - source TEXT NOT NULL, - title TEXT, - sandbox_policy TEXT, - approval_mode TEXT, - archived INTEGER NOT NULL DEFAULT 0, - archived_at INTEGER, - git_sha TEXT, - git_branch TEXT, - git_origin_url TEXT, - memory_mode TEXT - ); - CREATE INDEX IF NOT EXISTS idx_threads_updated_at ON threads(updated_at DESC); - CREATE INDEX IF NOT EXISTS idx_threads_archived_at ON threads(archived_at DESC); - CREATE INDEX IF NOT EXISTS idx_threads_archived_updated ON threads(archived, updated_at DESC); + let user_version: u32 = conn.query_row("PRAGMA user_version;", [], |row| row.get(0))?; + if user_version == 0 { + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + rollout_path TEXT, + preview TEXT NOT NULL, + ephemeral INTEGER NOT NULL, + model_provider TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + status TEXT NOT NULL, + path TEXT, + cwd TEXT NOT NULL, + cli_version TEXT NOT NULL, + source TEXT NOT NULL, + title TEXT, + sandbox_policy TEXT, + approval_mode TEXT, + archived INTEGER NOT NULL DEFAULT 0, + archived_at INTEGER, + git_sha TEXT, + git_branch TEXT, + git_origin_url TEXT, + memory_mode TEXT + ); + CREATE INDEX IF NOT EXISTS idx_threads_updated_at ON threads(updated_at DESC); + CREATE INDEX IF NOT EXISTS idx_threads_archived_at ON threads(archived_at DESC); + CREATE INDEX IF NOT EXISTS idx_threads_archived_updated ON threads(archived, updated_at DESC); - CREATE TABLE IF NOT EXISTS thread_dynamic_tools ( - thread_id TEXT NOT NULL, - position INTEGER NOT NULL, - name TEXT NOT NULL, - description TEXT, - input_schema TEXT NOT NULL, - PRIMARY KEY (thread_id, position), - FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE - ); + CREATE TABLE IF NOT EXISTS thread_dynamic_tools ( + thread_id TEXT NOT NULL, + position INTEGER NOT NULL, + name TEXT NOT NULL, + description TEXT, + input_schema TEXT NOT NULL, + PRIMARY KEY (thread_id, position), + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); - CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - thread_id TEXT NOT NULL, - role TEXT NOT NULL, - content TEXT NOT NULL, - item_json TEXT, - created_at INTEGER NOT NULL, - FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS idx_messages_thread_created_at ON messages(thread_id, created_at ASC); + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + thread_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + item_json TEXT, + created_at INTEGER NOT NULL, + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_messages_thread_created_at ON messages(thread_id, created_at ASC); - CREATE TABLE IF NOT EXISTS checkpoints ( - thread_id TEXT NOT NULL, - checkpoint_id TEXT NOT NULL, - state_json TEXT NOT NULL, - created_at INTEGER NOT NULL, - PRIMARY KEY(thread_id, checkpoint_id), - FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS idx_checkpoints_thread_created_at ON checkpoints(thread_id, created_at DESC); + CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + checkpoint_id TEXT NOT NULL, + state_json TEXT NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY(thread_id, checkpoint_id), + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_checkpoints_thread_created_at ON checkpoints(thread_id, created_at DESC); - CREATE TABLE IF NOT EXISTS jobs ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - status TEXT NOT NULL, - progress INTEGER, - detail TEXT, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_jobs_updated_at ON jobs(updated_at DESC); - "#, - ) - .context("failed to initialize thread schema")?; + CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + status TEXT NOT NULL, + progress INTEGER, + detail TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_jobs_updated_at ON jobs(updated_at DESC); + + -- Add parent_entry_id column, and set to last message before current message + ALTER TABLE messages ADD COLUMN parent_entry_id INTEGER NULL; + UPDATE messages + SET parent_entry_id = ( + SELECT m2.id + FROM messages m2 + WHERE m2.created_at < messages.created_at AND m2.thread_id = messages.thread_id + ORDER BY m2.created_at DESC + LIMIT 1 + ); + CREATE INDEX idx_messages_parent_entry_id ON messages(parent_entry_id); + + -- Add current_leaf_id column, and set to last message in thread + ALTER TABLE threads ADD COLUMN current_leaf_id INTEGER NULL; + UPDATE threads + SET current_leaf_id = ( + SELECT m.id + FROM messages m + WHERE m.thread_id = threads.id + ORDER BY m.created_at DESC + LIMIT 1 + ); + + PRAGMA user_version = 1; + "#, + ) + .context("failed to initialize thread schema")?; + } Ok(()) } @@ -245,11 +275,11 @@ impl StateStore { INSERT INTO threads ( id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, - git_sha, git_branch, git_origin_url, memory_mode + git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id ) VALUES ( ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, - ?18, ?19, ?20, ?21 + ?18, ?19, ?20, ?21, ?22 ) ON CONFLICT(id) DO UPDATE SET rollout_path=excluded.rollout_path, @@ -271,7 +301,8 @@ impl StateStore { git_sha=excluded.git_sha, git_branch=excluded.git_branch, git_origin_url=excluded.git_origin_url, - memory_mode=excluded.memory_mode + memory_mode=excluded.memory_mode, + current_leaf_id=excluded.current_leaf_id "#, params![ thread.id, @@ -295,6 +326,7 @@ impl StateStore { thread.git_branch, thread.git_origin_url, thread.memory_mode, + thread.current_leaf_id, ], ) .context("failed to upsert thread metadata")?; @@ -314,7 +346,7 @@ impl StateStore { r#" SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, - git_sha, git_branch, git_origin_url, memory_mode + git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id FROM threads WHERE id = ?1 "#, @@ -328,9 +360,9 @@ impl StateStore { pub fn list_threads(&self, filters: ThreadListFilters) -> Result> { let conn = self.conn()?; let sql = if filters.include_archived { - "SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode FROM threads ORDER BY updated_at DESC LIMIT ?1" + "SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id FROM threads ORDER BY updated_at DESC LIMIT ?1" } else { - "SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode FROM threads WHERE archived = 0 ORDER BY updated_at DESC LIMIT ?1" + "SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id FROM threads WHERE archived = 0 ORDER BY updated_at DESC LIMIT ?1" }; let mut stmt = conn.prepare(sql).context("failed to prepare list query")?; @@ -398,6 +430,54 @@ impl StateStore { .map(Option::flatten) } + pub fn list_leaf_messages(&self, thread_id: &str) -> Result> { + let conn = self.conn()?; + let mut stmt = conn + .prepare( + r#" + SELECT m1.id, m1.thread_id, m1.role, m1.content, m1.item_json, m1.created_at, m1.parent_entry_id + FROM messages m1 + LEFT JOIN messages m2 ON m1.id = m2.parent_entry_id + WHERE m1.thread_id = ?1 AND m2.id IS NULL + "#, + ) + .context("failed to prepare message listing query")?; + let mut rows = stmt + .query(params![thread_id]) + .with_context(|| format!("failed to list leaf messages for thread {thread_id}"))?; + let mut out = Vec::new(); + while let Some(row) = rows.next().context("failed to iterate message rows")? { + let item_json: Option = row.get(4).context("failed to read item json")?; + let item = item_json + .as_deref() + .map(serde_json::from_str) + .transpose() + .with_context(|| { + format!("failed to parse message item json in thread {thread_id}") + })?; + out.push(MessageRecord { + id: row.get(0).context("failed to read message id")?, + thread_id: row.get(1).context("failed to read message thread id")?, + role: row.get(2).context("failed to read message role")?, + content: row.get(3).context("failed to read message content")?, + item, + created_at: row.get(5).context("failed to read message timestamp")?, + parent_entry_id: row.get(6).context("failed to read parent entry id")?, + }); + } + Ok(out) + } + + pub fn set_current_leaf_id(&self, thread_id: &str, current_leaf_id: &str) -> Result<()> { + let conn = self.conn()?; + conn.execute( + "UPDATE threads SET current_leaf_id = ?1 WHERE id = ?2", + params![current_leaf_id, thread_id], + ) + .context("failed to update thread current leaf id")?; + Ok(()) + } + pub fn persist_dynamic_tools( &self, thread_id: &str, @@ -464,18 +544,51 @@ impl StateStore { content: &str, item: Option, ) -> Result { - let conn = self.conn()?; + let mut conn = self.conn()?; let created_at = Utc::now().timestamp(); let item_json = item .as_ref() .map(serde_json::to_string) .transpose() .context("failed to serialize message item payload")?; - conn.execute( - "INSERT INTO messages(thread_id, role, content, item_json, created_at) VALUES (?1, ?2, ?3, ?4, ?5)", - params![thread_id, role, content, item_json, created_at], + + let tx = conn + .transaction() + .context("failed to begin append message transaction")?; + + let current_leaf_id: Option = tx + .query_row( + "SELECT current_leaf_id FROM threads WHERE id = ?1", + params![thread_id], + |row| row.get(0), + ) + .with_context(|| { + format!("failed to query thread current leaf id for thread {thread_id}") + })?; + + let next_leaf_id: i64 = tx.query_row( + r#" + INSERT INTO messages(thread_id, role, content, item_json, created_at, parent_entry_id) + SELECT ?1, ?2, ?3, ?4, ?5, ?6 + RETURNING id + "#, params![thread_id, role, content, item_json, created_at, current_leaf_id], |row| row.get(0) + ).with_context(|| format!("failed to append message for thread {thread_id}"))?; + + tx.execute( + r#" + UPDATE threads + SET current_leaf_id = ?1 + WHERE id = ?2; + "#, + params![next_leaf_id, thread_id], ) - .with_context(|| format!("failed to append message for thread {thread_id}"))?; + .with_context(|| { + format!("failed to update thread current leaf id for thread {thread_id}") + })?; + + tx.commit() + .context("failed to commit append message transaction")?; + Ok(conn.last_insert_rowid()) } @@ -488,11 +601,30 @@ impl StateStore { let limit = i64::try_from(limit.unwrap_or(500)).unwrap_or(500); let mut stmt = conn .prepare( - "SELECT id, thread_id, role, content, item_json, created_at FROM messages WHERE thread_id = ?1 ORDER BY created_at ASC LIMIT ?2", + r#" + WITH RECURSIVE + leaf_id AS ( + SELECT current_leaf_id FROM threads WHERE id = ?1 + ), + ancestors AS ( + SELECT id, thread_id, role, content, item_json, created_at, parent_entry_id, 0 AS depth + FROM messages + WHERE id = (SELECT current_leaf_id FROM leaf_id) + + UNION ALL + + SELECT m.id, m.thread_id, m.role, m.content, m.item_json, m.created_at, m.parent_entry_id, a.depth + 1 + FROM messages m + JOIN ancestors a ON m.id = a.parent_entry_id + WHERE a.depth < ?2 + ) + SELECT id, thread_id, role, content, item_json, created_at, parent_entry_id FROM ancestors + ORDER BY depth ASC + "# ) .context("failed to prepare message listing query")?; let mut rows = stmt - .query(params![thread_id, limit]) + .query(params![thread_id, limit - 1]) .with_context(|| format!("failed to list messages for thread {thread_id}"))?; let mut out = Vec::new(); while let Some(row) = rows.next().context("failed to iterate message rows")? { @@ -511,11 +643,68 @@ impl StateStore { content: row.get(3).context("failed to read message content")?, item, created_at: row.get(5).context("failed to read message timestamp")?, + parent_entry_id: row.get(6).context("failed to read parent entry id")?, }); } Ok(out) } + pub fn fork_at_message( + &self, + message_id: &str, + role: &str, + content: &str, + item: Option, + ) -> Result { + let mut conn = self.conn()?; + let created_at = Utc::now().timestamp(); + let item_json = item + .as_ref() + .map(serde_json::to_string) + .transpose() + .context("failed to serialize message item payload")?; + + let tx = conn + .transaction() + .context("failed to begin fork message transaction")?; + + let thread_id: Option = tx + .query_row( + "SELECT thread_id FROM messages WHERE id = ?1", + params![message_id], + |row| row.get(0), + ) + .with_context(|| format!("failed to query thread id for message {message_id}"))?; + + let next_leaf_id: i64 = tx.query_row( + r#" + INSERT INTO messages(thread_id, role, content, item_json, created_at, parent_entry_id) + SELECT ?1, ?2, ?3, ?4, ?5, ?6 + RETURNING id + "#, params![thread_id, role, content, item_json, created_at, message_id], |row| row.get(0) + ).with_context(|| format!("failed to fork at message for thread {:?}", thread_id))?; + + tx.execute( + r#" + UPDATE threads + SET current_leaf_id = ?1 + WHERE id = ?2; + "#, + params![next_leaf_id, thread_id], + ) + .with_context(|| { + format!( + "failed to update thread current leaf id for thread {:?}", + thread_id + ) + })?; + + tx.commit() + .context("failed to commit fork message transaction")?; + + Ok(next_leaf_id) + } + pub fn clear_messages(&self, thread_id: &str) -> Result { let conn = self.conn()?; conn.execute( @@ -946,5 +1135,6 @@ fn row_to_thread(row: &rusqlite::Row<'_>) -> rusqlite::Result { git_branch: row.get(18)?, git_origin_url: row.get(19)?, memory_mode: row.get(20)?, + current_leaf_id: row.get(21)?, }) } diff --git a/crates/state/tests/parity_state.rs b/crates/state/tests/parity_state.rs index d666f50b..cee9192a 100644 --- a/crates/state/tests/parity_state.rs +++ b/crates/state/tests/parity_state.rs @@ -1,6 +1,7 @@ use std::path::PathBuf; use codewhale_state::{SessionSource, StateStore, ThreadListFilters, ThreadMetadata, ThreadStatus}; +use rusqlite::Connection; fn temp_state_path(label: &str) -> PathBuf { std::env::temp_dir().join(format!( @@ -38,6 +39,7 @@ fn upsert_and_resume_thread_metadata() { git_branch: None, git_origin_url: None, memory_mode: Some("extended".to_string()), + current_leaf_id: None, }; store.upsert_thread(&thread).expect("upsert thread"); @@ -70,3 +72,196 @@ fn upsert_and_resume_thread_metadata() { .expect("list threads"); assert!(!listed.is_empty()); } + +#[test] +fn init_schema_migration() { + let path = temp_state_path("init_schema_migration"); + let conn = Connection::open(&path).expect("open state db"); + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + rollout_path TEXT, + preview TEXT NOT NULL, + ephemeral INTEGER NOT NULL, + model_provider TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + status TEXT NOT NULL, + path TEXT, + cwd TEXT NOT NULL, + cli_version TEXT NOT NULL, + source TEXT NOT NULL, + title TEXT, + sandbox_policy TEXT, + approval_mode TEXT, + archived INTEGER NOT NULL DEFAULT 0, + archived_at INTEGER, + git_sha TEXT, + git_branch TEXT, + git_origin_url TEXT, + memory_mode TEXT + ); + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + thread_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + item_json TEXT, + created_at INTEGER NOT NULL, + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); + INSERT INTO threads ( + id, preview, ephemeral, model_provider, created_at, updated_at, status, cwd, cli_version, source, archived + ) + VALUES ( + 'thread-test-1', 'hello', false, 'deepseek', 0, 0, 'running', '/tmp/project', '0.0.0-test', 'interactive', false + ); + INSERT INTO messages (thread_id, role, content, created_at) VALUES + ('thread-test-1', 'foo0', 'bar0', 0), + ('thread-test-1', 'foo1', 'bar1', 1), + ('thread-test-1', 'foo2', 'bar2', 2); + "#, + ) + .expect("init schema migration"); + + let store = StateStore::open(Some(path.clone())).expect("open state store"); + let thread = store + .get_thread("thread-test-1") + .expect("read thread") + .unwrap(); + assert_eq!(thread.id, "thread-test-1"); + assert_eq!(thread.preview, "hello"); + assert!(!thread.ephemeral); + assert_eq!(thread.model_provider, "deepseek"); + assert_eq!(thread.created_at, 0); + assert_eq!(thread.updated_at, 0); + assert_eq!(thread.status, ThreadStatus::Running); + assert_eq!(thread.cwd, PathBuf::from("/tmp/project")); + assert_eq!(thread.cli_version, "0.0.0-test"); + assert_eq!(thread.source, SessionSource::Interactive); + assert!(thread.current_leaf_id.is_some()); + + let messages = store + .list_messages("thread-test-1", None) + .expect("list messages"); + assert_eq!(messages.len(), 3); + for (i, message) in messages.iter().enumerate() { + assert_eq!(message.thread_id, "thread-test-1"); + assert_eq!(message.role, format!("foo{}", 2 - i)); + assert_eq!(message.content, format!("bar{}", 2 - i)); + assert_eq!(message.created_at, 2 - i as i64); + } + + // Test idempotent + StateStore::open(Some(path.clone())).expect("open state store"); +} + +#[test] +fn test_fork() { + let path = temp_state_path("test_fork"); + let store = StateStore::open(Some(path.clone())).expect("open state store"); + let now = chrono::Utc::now().timestamp(); + let thread = ThreadMetadata { + id: "thread-test-1".to_string(), + rollout_path: Some(PathBuf::from("/tmp/rollout.jsonl")), + preview: "hello".to_string(), + ephemeral: false, + model_provider: "deepseek".to_string(), + created_at: now, + updated_at: now, + status: ThreadStatus::Running, + path: Some(PathBuf::from("/tmp/project")), + cwd: PathBuf::from("/tmp/project"), + cli_version: "0.0.0-test".to_string(), + source: SessionSource::Interactive, + name: Some("Test Thread".to_string()), + sandbox_policy: Some("workspace-write".to_string()), + approval_mode: Some("on-request".to_string()), + archived: false, + archived_at: None, + git_sha: None, + git_branch: None, + git_origin_url: None, + memory_mode: Some("extended".to_string()), + current_leaf_id: None, + }; + + store.upsert_thread(&thread).expect("upsert thread"); + store + .append_message("thread-test-1", "foo0", "bar0", None) + .expect("append message"); + store + .append_message("thread-test-1", "foo1", "bar1", None) + .expect("append message"); + store + .append_message("thread-test-1", "foo2", "bar2", None) + .expect("append message"); + store + .append_message("thread-test-1", "foo3", "bar3", None) + .expect("append message"); + store + .append_message("thread-test-1", "foo4", "bar4", None) + .expect("append message"); + + let messages = store + .list_messages("thread-test-1", None) + .expect("list messages"); + assert_eq!(messages.len(), 5); + let ids = messages + .iter() + .enumerate() + .map(|(i, message)| { + assert_eq!(message.thread_id, "thread-test-1"); + assert_eq!(message.role, format!("foo{}", 4 - i)); + assert_eq!(message.content, format!("bar{}", 4 - i)); + message.id.to_string() + }) + .collect::>(); + + store + .fork_at_message(&ids[2], "foo5", "bar5", None) + .expect("fork at message"); + let messages = store + .list_messages("thread-test-1", None) + .expect("list messages"); + assert_eq!(messages.len(), 4); + const LIST_1: [i64; 4] = [5, 2, 1, 0]; + messages + .iter() + .zip(LIST_1.iter()) + .for_each(|(message, &i)| { + assert_eq!(message.thread_id, "thread-test-1"); + assert_eq!(message.role, format!("foo{}", i)); + assert_eq!(message.content, format!("bar{}", i)); + }); + let leaves = store + .list_leaf_messages("thread-test-1") + .expect("list leaf messages"); + assert_eq!(leaves.len(), 2); + + store + .set_current_leaf_id("thread-test-1", &ids[0]) + .expect("set current leaf id"); + store + .append_message("thread-test-1", "foo6", "bar6", None) + .expect("append message"); + let messages = store + .list_messages("thread-test-1", None) + .expect("list messages"); + assert_eq!(messages.len(), 6); + const LIST_2: [i64; 6] = [6, 4, 3, 2, 1, 0]; + messages + .iter() + .zip(LIST_2.iter()) + .for_each(|(message, &i)| { + assert_eq!(message.thread_id, "thread-test-1"); + assert_eq!(message.role, format!("foo{}", i)); + assert_eq!(message.content, format!("bar{}", i)); + }); + + let leaves = store + .list_leaf_messages("thread-test-1") + .expect("list leaf messages"); + assert_eq!(leaves.len(), 2); +}