feat(state): add parent_entry_id on the message table for fork support
This commit is contained in:
@@ -643,6 +643,7 @@ impl ThreadManager {
|
||||
git_branch: None,
|
||||
git_origin_url: None,
|
||||
memory_mode: None,
|
||||
current_leaf_id: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+272
-82
@@ -53,6 +53,7 @@ pub struct ThreadMetadata {
|
||||
pub git_branch: Option<String>,
|
||||
pub git_origin_url: Option<String>,
|
||||
pub memory_mode: Option<String>,
|
||||
pub current_leaf_id: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -71,6 +72,7 @@ pub struct MessageRecord {
|
||||
pub content: String,
|
||||
pub item: Option<Value>,
|
||||
pub created_at: i64,
|
||||
pub parent_entry_id: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -162,79 +164,107 @@ impl StateStore {
|
||||
|
||||
fn init_schema(&self) -> Result<()> {
|
||||
let conn = self.conn()?;
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
rollout_path TEXT,
|
||||
preview TEXT NOT NULL,
|
||||
ephemeral INTEGER NOT NULL,
|
||||
model_provider TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
path TEXT,
|
||||
cwd TEXT NOT NULL,
|
||||
cli_version TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
title TEXT,
|
||||
sandbox_policy TEXT,
|
||||
approval_mode TEXT,
|
||||
archived INTEGER NOT NULL DEFAULT 0,
|
||||
archived_at INTEGER,
|
||||
git_sha TEXT,
|
||||
git_branch TEXT,
|
||||
git_origin_url TEXT,
|
||||
memory_mode TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_threads_updated_at ON threads(updated_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_threads_archived_at ON threads(archived_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_threads_archived_updated ON threads(archived, updated_at DESC);
|
||||
let user_version: u32 = conn.query_row("PRAGMA user_version;", [], |row| row.get(0))?;
|
||||
if user_version == 0 {
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
rollout_path TEXT,
|
||||
preview TEXT NOT NULL,
|
||||
ephemeral INTEGER NOT NULL,
|
||||
model_provider TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
path TEXT,
|
||||
cwd TEXT NOT NULL,
|
||||
cli_version TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
title TEXT,
|
||||
sandbox_policy TEXT,
|
||||
approval_mode TEXT,
|
||||
archived INTEGER NOT NULL DEFAULT 0,
|
||||
archived_at INTEGER,
|
||||
git_sha TEXT,
|
||||
git_branch TEXT,
|
||||
git_origin_url TEXT,
|
||||
memory_mode TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_threads_updated_at ON threads(updated_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_threads_archived_at ON threads(archived_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_threads_archived_updated ON threads(archived, updated_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS thread_dynamic_tools (
|
||||
thread_id TEXT NOT NULL,
|
||||
position INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
input_schema TEXT NOT NULL,
|
||||
PRIMARY KEY (thread_id, position),
|
||||
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS thread_dynamic_tools (
|
||||
thread_id TEXT NOT NULL,
|
||||
position INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
input_schema TEXT NOT NULL,
|
||||
PRIMARY KEY (thread_id, position),
|
||||
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
thread_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
item_json TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_thread_created_at ON messages(thread_id, created_at ASC);
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
thread_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
item_json TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_thread_created_at ON messages(thread_id, created_at ASC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||
thread_id TEXT NOT NULL,
|
||||
checkpoint_id TEXT NOT NULL,
|
||||
state_json TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
PRIMARY KEY(thread_id, checkpoint_id),
|
||||
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_checkpoints_thread_created_at ON checkpoints(thread_id, created_at DESC);
|
||||
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||
thread_id TEXT NOT NULL,
|
||||
checkpoint_id TEXT NOT NULL,
|
||||
state_json TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
PRIMARY KEY(thread_id, checkpoint_id),
|
||||
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_checkpoints_thread_created_at ON checkpoints(thread_id, created_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
progress INTEGER,
|
||||
detail TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_updated_at ON jobs(updated_at DESC);
|
||||
"#,
|
||||
)
|
||||
.context("failed to initialize thread schema")?;
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
progress INTEGER,
|
||||
detail TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_updated_at ON jobs(updated_at DESC);
|
||||
|
||||
-- Add parent_entry_id column, and set to last message before current message
|
||||
ALTER TABLE messages ADD COLUMN parent_entry_id INTEGER NULL;
|
||||
UPDATE messages
|
||||
SET parent_entry_id = (
|
||||
SELECT m2.id
|
||||
FROM messages m2
|
||||
WHERE m2.created_at < messages.created_at AND m2.thread_id = messages.thread_id
|
||||
ORDER BY m2.created_at DESC
|
||||
LIMIT 1
|
||||
);
|
||||
CREATE INDEX idx_messages_parent_entry_id ON messages(parent_entry_id);
|
||||
|
||||
-- Add current_leaf_id column, and set to last message in thread
|
||||
ALTER TABLE threads ADD COLUMN current_leaf_id INTEGER NULL;
|
||||
UPDATE threads
|
||||
SET current_leaf_id = (
|
||||
SELECT m.id
|
||||
FROM messages m
|
||||
WHERE m.thread_id = threads.id
|
||||
ORDER BY m.created_at DESC
|
||||
LIMIT 1
|
||||
);
|
||||
|
||||
PRAGMA user_version = 1;
|
||||
"#,
|
||||
)
|
||||
.context("failed to initialize thread schema")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -245,11 +275,11 @@ impl StateStore {
|
||||
INSERT INTO threads (
|
||||
id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd,
|
||||
cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at,
|
||||
git_sha, git_branch, git_origin_url, memory_mode
|
||||
git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id
|
||||
) VALUES (
|
||||
?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10,
|
||||
?11, ?12, ?13, ?14, ?15, ?16, ?17,
|
||||
?18, ?19, ?20, ?21
|
||||
?18, ?19, ?20, ?21, ?22
|
||||
)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
rollout_path=excluded.rollout_path,
|
||||
@@ -271,7 +301,8 @@ impl StateStore {
|
||||
git_sha=excluded.git_sha,
|
||||
git_branch=excluded.git_branch,
|
||||
git_origin_url=excluded.git_origin_url,
|
||||
memory_mode=excluded.memory_mode
|
||||
memory_mode=excluded.memory_mode,
|
||||
current_leaf_id=excluded.current_leaf_id
|
||||
"#,
|
||||
params![
|
||||
thread.id,
|
||||
@@ -295,6 +326,7 @@ impl StateStore {
|
||||
thread.git_branch,
|
||||
thread.git_origin_url,
|
||||
thread.memory_mode,
|
||||
thread.current_leaf_id,
|
||||
],
|
||||
)
|
||||
.context("failed to upsert thread metadata")?;
|
||||
@@ -314,7 +346,7 @@ impl StateStore {
|
||||
r#"
|
||||
SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd,
|
||||
cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at,
|
||||
git_sha, git_branch, git_origin_url, memory_mode
|
||||
git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id
|
||||
FROM threads
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
@@ -328,9 +360,9 @@ impl StateStore {
|
||||
pub fn list_threads(&self, filters: ThreadListFilters) -> Result<Vec<ThreadMetadata>> {
|
||||
let conn = self.conn()?;
|
||||
let sql = if filters.include_archived {
|
||||
"SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode FROM threads ORDER BY updated_at DESC LIMIT ?1"
|
||||
"SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id FROM threads ORDER BY updated_at DESC LIMIT ?1"
|
||||
} else {
|
||||
"SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode FROM threads WHERE archived = 0 ORDER BY updated_at DESC LIMIT ?1"
|
||||
"SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode, current_leaf_id FROM threads WHERE archived = 0 ORDER BY updated_at DESC LIMIT ?1"
|
||||
};
|
||||
|
||||
let mut stmt = conn.prepare(sql).context("failed to prepare list query")?;
|
||||
@@ -398,6 +430,54 @@ impl StateStore {
|
||||
.map(Option::flatten)
|
||||
}
|
||||
|
||||
pub fn list_leaf_messages(&self, thread_id: &str) -> Result<Vec<MessageRecord>> {
|
||||
let conn = self.conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT m1.id, m1.thread_id, m1.role, m1.content, m1.item_json, m1.created_at, m1.parent_entry_id
|
||||
FROM messages m1
|
||||
LEFT JOIN messages m2 ON m1.id = m2.parent_entry_id
|
||||
WHERE m1.thread_id = ?1 AND m2.id IS NULL
|
||||
"#,
|
||||
)
|
||||
.context("failed to prepare message listing query")?;
|
||||
let mut rows = stmt
|
||||
.query(params![thread_id])
|
||||
.with_context(|| format!("failed to list leaf messages for thread {thread_id}"))?;
|
||||
let mut out = Vec::new();
|
||||
while let Some(row) = rows.next().context("failed to iterate message rows")? {
|
||||
let item_json: Option<String> = row.get(4).context("failed to read item json")?;
|
||||
let item = item_json
|
||||
.as_deref()
|
||||
.map(serde_json::from_str)
|
||||
.transpose()
|
||||
.with_context(|| {
|
||||
format!("failed to parse message item json in thread {thread_id}")
|
||||
})?;
|
||||
out.push(MessageRecord {
|
||||
id: row.get(0).context("failed to read message id")?,
|
||||
thread_id: row.get(1).context("failed to read message thread id")?,
|
||||
role: row.get(2).context("failed to read message role")?,
|
||||
content: row.get(3).context("failed to read message content")?,
|
||||
item,
|
||||
created_at: row.get(5).context("failed to read message timestamp")?,
|
||||
parent_entry_id: row.get(6).context("failed to read parent entry id")?,
|
||||
});
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
pub fn set_current_leaf_id(&self, thread_id: &str, current_leaf_id: &str) -> Result<()> {
|
||||
let conn = self.conn()?;
|
||||
conn.execute(
|
||||
"UPDATE threads SET current_leaf_id = ?1 WHERE id = ?2",
|
||||
params![current_leaf_id, thread_id],
|
||||
)
|
||||
.context("failed to update thread current leaf id")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn persist_dynamic_tools(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
@@ -464,18 +544,51 @@ impl StateStore {
|
||||
content: &str,
|
||||
item: Option<Value>,
|
||||
) -> Result<i64> {
|
||||
let conn = self.conn()?;
|
||||
let mut conn = self.conn()?;
|
||||
let created_at = Utc::now().timestamp();
|
||||
let item_json = item
|
||||
.as_ref()
|
||||
.map(serde_json::to_string)
|
||||
.transpose()
|
||||
.context("failed to serialize message item payload")?;
|
||||
conn.execute(
|
||||
"INSERT INTO messages(thread_id, role, content, item_json, created_at) VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||
params![thread_id, role, content, item_json, created_at],
|
||||
|
||||
let tx = conn
|
||||
.transaction()
|
||||
.context("failed to begin append message transaction")?;
|
||||
|
||||
let current_leaf_id: Option<i64> = tx
|
||||
.query_row(
|
||||
"SELECT current_leaf_id FROM threads WHERE id = ?1",
|
||||
params![thread_id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!("failed to query thread current leaf id for thread {thread_id}")
|
||||
})?;
|
||||
|
||||
let next_leaf_id: i64 = tx.query_row(
|
||||
r#"
|
||||
INSERT INTO messages(thread_id, role, content, item_json, created_at, parent_entry_id)
|
||||
SELECT ?1, ?2, ?3, ?4, ?5, ?6
|
||||
RETURNING id
|
||||
"#, params![thread_id, role, content, item_json, created_at, current_leaf_id], |row| row.get(0)
|
||||
).with_context(|| format!("failed to append message for thread {thread_id}"))?;
|
||||
|
||||
tx.execute(
|
||||
r#"
|
||||
UPDATE threads
|
||||
SET current_leaf_id = ?1
|
||||
WHERE id = ?2;
|
||||
"#,
|
||||
params![next_leaf_id, thread_id],
|
||||
)
|
||||
.with_context(|| format!("failed to append message for thread {thread_id}"))?;
|
||||
.with_context(|| {
|
||||
format!("failed to update thread current leaf id for thread {thread_id}")
|
||||
})?;
|
||||
|
||||
tx.commit()
|
||||
.context("failed to commit append message transaction")?;
|
||||
|
||||
Ok(conn.last_insert_rowid())
|
||||
}
|
||||
|
||||
@@ -488,11 +601,30 @@ impl StateStore {
|
||||
let limit = i64::try_from(limit.unwrap_or(500)).unwrap_or(500);
|
||||
let mut stmt = conn
|
||||
.prepare(
|
||||
"SELECT id, thread_id, role, content, item_json, created_at FROM messages WHERE thread_id = ?1 ORDER BY created_at ASC LIMIT ?2",
|
||||
r#"
|
||||
WITH RECURSIVE
|
||||
leaf_id AS (
|
||||
SELECT current_leaf_id FROM threads WHERE id = ?1
|
||||
),
|
||||
ancestors AS (
|
||||
SELECT id, thread_id, role, content, item_json, created_at, parent_entry_id, 0 AS depth
|
||||
FROM messages
|
||||
WHERE id = (SELECT current_leaf_id FROM leaf_id)
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT m.id, m.thread_id, m.role, m.content, m.item_json, m.created_at, m.parent_entry_id, a.depth + 1
|
||||
FROM messages m
|
||||
JOIN ancestors a ON m.id = a.parent_entry_id
|
||||
WHERE a.depth < ?2
|
||||
)
|
||||
SELECT id, thread_id, role, content, item_json, created_at, parent_entry_id FROM ancestors
|
||||
ORDER BY depth ASC
|
||||
"#
|
||||
)
|
||||
.context("failed to prepare message listing query")?;
|
||||
let mut rows = stmt
|
||||
.query(params![thread_id, limit])
|
||||
.query(params![thread_id, limit - 1])
|
||||
.with_context(|| format!("failed to list messages for thread {thread_id}"))?;
|
||||
let mut out = Vec::new();
|
||||
while let Some(row) = rows.next().context("failed to iterate message rows")? {
|
||||
@@ -511,11 +643,68 @@ impl StateStore {
|
||||
content: row.get(3).context("failed to read message content")?,
|
||||
item,
|
||||
created_at: row.get(5).context("failed to read message timestamp")?,
|
||||
parent_entry_id: row.get(6).context("failed to read parent entry id")?,
|
||||
});
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
pub fn fork_at_message(
|
||||
&self,
|
||||
message_id: &str,
|
||||
role: &str,
|
||||
content: &str,
|
||||
item: Option<Value>,
|
||||
) -> Result<i64> {
|
||||
let mut conn = self.conn()?;
|
||||
let created_at = Utc::now().timestamp();
|
||||
let item_json = item
|
||||
.as_ref()
|
||||
.map(serde_json::to_string)
|
||||
.transpose()
|
||||
.context("failed to serialize message item payload")?;
|
||||
|
||||
let tx = conn
|
||||
.transaction()
|
||||
.context("failed to begin fork message transaction")?;
|
||||
|
||||
let thread_id: Option<String> = tx
|
||||
.query_row(
|
||||
"SELECT thread_id FROM messages WHERE id = ?1",
|
||||
params![message_id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.with_context(|| format!("failed to query thread id for message {message_id}"))?;
|
||||
|
||||
let next_leaf_id: i64 = tx.query_row(
|
||||
r#"
|
||||
INSERT INTO messages(thread_id, role, content, item_json, created_at, parent_entry_id)
|
||||
SELECT ?1, ?2, ?3, ?4, ?5, ?6
|
||||
RETURNING id
|
||||
"#, params![thread_id, role, content, item_json, created_at, message_id], |row| row.get(0)
|
||||
).with_context(|| format!("failed to fork at message for thread {:?}", thread_id))?;
|
||||
|
||||
tx.execute(
|
||||
r#"
|
||||
UPDATE threads
|
||||
SET current_leaf_id = ?1
|
||||
WHERE id = ?2;
|
||||
"#,
|
||||
params![next_leaf_id, thread_id],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to update thread current leaf id for thread {:?}",
|
||||
thread_id
|
||||
)
|
||||
})?;
|
||||
|
||||
tx.commit()
|
||||
.context("failed to commit fork message transaction")?;
|
||||
|
||||
Ok(next_leaf_id)
|
||||
}
|
||||
|
||||
pub fn clear_messages(&self, thread_id: &str) -> Result<usize> {
|
||||
let conn = self.conn()?;
|
||||
conn.execute(
|
||||
@@ -946,5 +1135,6 @@ fn row_to_thread(row: &rusqlite::Row<'_>) -> rusqlite::Result<ThreadMetadata> {
|
||||
git_branch: row.get(18)?,
|
||||
git_origin_url: row.get(19)?,
|
||||
memory_mode: row.get(20)?,
|
||||
current_leaf_id: row.get(21)?,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codewhale_state::{SessionSource, StateStore, ThreadListFilters, ThreadMetadata, ThreadStatus};
|
||||
use rusqlite::Connection;
|
||||
|
||||
fn temp_state_path(label: &str) -> PathBuf {
|
||||
std::env::temp_dir().join(format!(
|
||||
@@ -38,6 +39,7 @@ fn upsert_and_resume_thread_metadata() {
|
||||
git_branch: None,
|
||||
git_origin_url: None,
|
||||
memory_mode: Some("extended".to_string()),
|
||||
current_leaf_id: None,
|
||||
};
|
||||
store.upsert_thread(&thread).expect("upsert thread");
|
||||
|
||||
@@ -70,3 +72,196 @@ fn upsert_and_resume_thread_metadata() {
|
||||
.expect("list threads");
|
||||
assert!(!listed.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn init_schema_migration() {
|
||||
let path = temp_state_path("init_schema_migration");
|
||||
let conn = Connection::open(&path).expect("open state db");
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
rollout_path TEXT,
|
||||
preview TEXT NOT NULL,
|
||||
ephemeral INTEGER NOT NULL,
|
||||
model_provider TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
path TEXT,
|
||||
cwd TEXT NOT NULL,
|
||||
cli_version TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
title TEXT,
|
||||
sandbox_policy TEXT,
|
||||
approval_mode TEXT,
|
||||
archived INTEGER NOT NULL DEFAULT 0,
|
||||
archived_at INTEGER,
|
||||
git_sha TEXT,
|
||||
git_branch TEXT,
|
||||
git_origin_url TEXT,
|
||||
memory_mode TEXT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
thread_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
item_json TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO threads (
|
||||
id, preview, ephemeral, model_provider, created_at, updated_at, status, cwd, cli_version, source, archived
|
||||
)
|
||||
VALUES (
|
||||
'thread-test-1', 'hello', false, 'deepseek', 0, 0, 'running', '/tmp/project', '0.0.0-test', 'interactive', false
|
||||
);
|
||||
INSERT INTO messages (thread_id, role, content, created_at) VALUES
|
||||
('thread-test-1', 'foo0', 'bar0', 0),
|
||||
('thread-test-1', 'foo1', 'bar1', 1),
|
||||
('thread-test-1', 'foo2', 'bar2', 2);
|
||||
"#,
|
||||
)
|
||||
.expect("init schema migration");
|
||||
|
||||
let store = StateStore::open(Some(path.clone())).expect("open state store");
|
||||
let thread = store
|
||||
.get_thread("thread-test-1")
|
||||
.expect("read thread")
|
||||
.unwrap();
|
||||
assert_eq!(thread.id, "thread-test-1");
|
||||
assert_eq!(thread.preview, "hello");
|
||||
assert!(!thread.ephemeral);
|
||||
assert_eq!(thread.model_provider, "deepseek");
|
||||
assert_eq!(thread.created_at, 0);
|
||||
assert_eq!(thread.updated_at, 0);
|
||||
assert_eq!(thread.status, ThreadStatus::Running);
|
||||
assert_eq!(thread.cwd, PathBuf::from("/tmp/project"));
|
||||
assert_eq!(thread.cli_version, "0.0.0-test");
|
||||
assert_eq!(thread.source, SessionSource::Interactive);
|
||||
assert!(thread.current_leaf_id.is_some());
|
||||
|
||||
let messages = store
|
||||
.list_messages("thread-test-1", None)
|
||||
.expect("list messages");
|
||||
assert_eq!(messages.len(), 3);
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
assert_eq!(message.thread_id, "thread-test-1");
|
||||
assert_eq!(message.role, format!("foo{}", 2 - i));
|
||||
assert_eq!(message.content, format!("bar{}", 2 - i));
|
||||
assert_eq!(message.created_at, 2 - i as i64);
|
||||
}
|
||||
|
||||
// Test idempotent
|
||||
StateStore::open(Some(path.clone())).expect("open state store");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fork() {
|
||||
let path = temp_state_path("test_fork");
|
||||
let store = StateStore::open(Some(path.clone())).expect("open state store");
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let thread = ThreadMetadata {
|
||||
id: "thread-test-1".to_string(),
|
||||
rollout_path: Some(PathBuf::from("/tmp/rollout.jsonl")),
|
||||
preview: "hello".to_string(),
|
||||
ephemeral: false,
|
||||
model_provider: "deepseek".to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
status: ThreadStatus::Running,
|
||||
path: Some(PathBuf::from("/tmp/project")),
|
||||
cwd: PathBuf::from("/tmp/project"),
|
||||
cli_version: "0.0.0-test".to_string(),
|
||||
source: SessionSource::Interactive,
|
||||
name: Some("Test Thread".to_string()),
|
||||
sandbox_policy: Some("workspace-write".to_string()),
|
||||
approval_mode: Some("on-request".to_string()),
|
||||
archived: false,
|
||||
archived_at: None,
|
||||
git_sha: None,
|
||||
git_branch: None,
|
||||
git_origin_url: None,
|
||||
memory_mode: Some("extended".to_string()),
|
||||
current_leaf_id: None,
|
||||
};
|
||||
|
||||
store.upsert_thread(&thread).expect("upsert thread");
|
||||
store
|
||||
.append_message("thread-test-1", "foo0", "bar0", None)
|
||||
.expect("append message");
|
||||
store
|
||||
.append_message("thread-test-1", "foo1", "bar1", None)
|
||||
.expect("append message");
|
||||
store
|
||||
.append_message("thread-test-1", "foo2", "bar2", None)
|
||||
.expect("append message");
|
||||
store
|
||||
.append_message("thread-test-1", "foo3", "bar3", None)
|
||||
.expect("append message");
|
||||
store
|
||||
.append_message("thread-test-1", "foo4", "bar4", None)
|
||||
.expect("append message");
|
||||
|
||||
let messages = store
|
||||
.list_messages("thread-test-1", None)
|
||||
.expect("list messages");
|
||||
assert_eq!(messages.len(), 5);
|
||||
let ids = messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, message)| {
|
||||
assert_eq!(message.thread_id, "thread-test-1");
|
||||
assert_eq!(message.role, format!("foo{}", 4 - i));
|
||||
assert_eq!(message.content, format!("bar{}", 4 - i));
|
||||
message.id.to_string()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
store
|
||||
.fork_at_message(&ids[2], "foo5", "bar5", None)
|
||||
.expect("fork at message");
|
||||
let messages = store
|
||||
.list_messages("thread-test-1", None)
|
||||
.expect("list messages");
|
||||
assert_eq!(messages.len(), 4);
|
||||
const LIST_1: [i64; 4] = [5, 2, 1, 0];
|
||||
messages
|
||||
.iter()
|
||||
.zip(LIST_1.iter())
|
||||
.for_each(|(message, &i)| {
|
||||
assert_eq!(message.thread_id, "thread-test-1");
|
||||
assert_eq!(message.role, format!("foo{}", i));
|
||||
assert_eq!(message.content, format!("bar{}", i));
|
||||
});
|
||||
let leaves = store
|
||||
.list_leaf_messages("thread-test-1")
|
||||
.expect("list leaf messages");
|
||||
assert_eq!(leaves.len(), 2);
|
||||
|
||||
store
|
||||
.set_current_leaf_id("thread-test-1", &ids[0])
|
||||
.expect("set current leaf id");
|
||||
store
|
||||
.append_message("thread-test-1", "foo6", "bar6", None)
|
||||
.expect("append message");
|
||||
let messages = store
|
||||
.list_messages("thread-test-1", None)
|
||||
.expect("list messages");
|
||||
assert_eq!(messages.len(), 6);
|
||||
const LIST_2: [i64; 6] = [6, 4, 3, 2, 1, 0];
|
||||
messages
|
||||
.iter()
|
||||
.zip(LIST_2.iter())
|
||||
.for_each(|(message, &i)| {
|
||||
assert_eq!(message.thread_id, "thread-test-1");
|
||||
assert_eq!(message.role, format!("foo{}", i));
|
||||
assert_eq!(message.content, format!("bar{}", i));
|
||||
});
|
||||
|
||||
let leaves = store
|
||||
.list_leaf_messages("thread-test-1")
|
||||
.expect("list leaf messages");
|
||||
assert_eq!(leaves.len(), 2);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user