Files
codewhale/crates/state/tests/parity_state.rs
T

477 lines
16 KiB
Rust

use std::path::PathBuf;
use codewhale_state::{SessionSource, StateStore, ThreadListFilters, ThreadMetadata, ThreadStatus};
use rusqlite::Connection;
/// Builds a path for a throwaway state database inside the OS temp directory.
///
/// The file name is made of `label`, the current process id and the current
/// UTC time in nanoseconds (`0` is used when `timestamp_nanos_opt` returns
/// `None`). Only the path is computed here; no file is created.
fn temp_state_path(label: &str) -> PathBuf {
    let pid = std::process::id();
    let nanos = chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0);
    let file_name = format!("deepseek_state_test_{}_{}_{}.db", label, pid, nanos);

    let mut db_path = std::env::temp_dir();
    db_path.push(file_name);
    db_path
}
/// Asserts that the database behind `conn` carries the workflow-trace schema.
///
/// Two things are checked: `PRAGMA user_version` must read `3`, and every
/// table named in `EXPECTED_TABLES` must be present in `sqlite_master`.
///
/// # Panics
///
/// Panics when either query fails, when the version differs from `3`, or when
/// one of the expected tables is missing.
fn assert_workflow_trace_schema(conn: &Connection) {
    const EXPECTED_TABLES: [&str; 6] = [
        "workflow_runs",
        "branch_runs",
        "leaf_runs",
        "control_node_runs",
        "teacher_candidates",
        "thread_goals",
    ];

    let user_version = conn
        .query_row("PRAGMA user_version;", [], |row| row.get::<_, u32>(0))
        .expect("read user_version");
    assert_eq!(user_version, 3);

    for table in EXPECTED_TABLES {
        let lookup = conn.query_row(
            "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?1)",
            [table],
            |row| row.get::<_, bool>(0),
        );
        match lookup {
            Ok(exists) => assert!(exists, "missing workflow trace table {table}"),
            Err(err) => panic!("read sqlite_master for {table}: {err}"),
        }
    }
}
/// Round-trips one thread through a fresh store: upsert, read back, archive,
/// and list with archived threads included.
#[test]
fn upsert_and_resume_thread_metadata() {
    const THREAD_ID: &str = "thread-test-1";

    let db_path = temp_state_path("upsert_resume");
    let state = StateStore::open(Some(db_path.clone())).expect("open state store");
    let timestamp = chrono::Utc::now().timestamp();

    let metadata = ThreadMetadata {
        id: THREAD_ID.to_owned(),
        rollout_path: Some(PathBuf::from("/tmp/rollout.jsonl")),
        preview: String::from("hello"),
        ephemeral: false,
        model_provider: String::from("deepseek"),
        created_at: timestamp,
        updated_at: timestamp,
        status: ThreadStatus::Running,
        path: Some(PathBuf::from("/tmp/project")),
        cwd: PathBuf::from("/tmp/project"),
        cli_version: String::from("0.0.0-test"),
        source: SessionSource::Interactive,
        name: Some(String::from("Test Thread")),
        sandbox_policy: Some(String::from("workspace-write")),
        approval_mode: Some(String::from("on-request")),
        archived: false,
        archived_at: None,
        git_sha: None,
        git_branch: None,
        git_origin_url: None,
        memory_mode: Some(String::from("extended")),
        current_leaf_id: None,
    };
    state.upsert_thread(&metadata).expect("upsert thread");

    // The stored row must echo the fields that were written.
    let fetched = state
        .get_thread(THREAD_ID)
        .expect("read thread")
        .expect("thread must exist");
    assert_eq!(fetched.id, THREAD_ID);
    assert_eq!(fetched.name.as_deref(), Some("Test Thread"));
    assert_eq!(fetched.memory_mode.as_deref(), Some("extended"));
    assert_eq!(
        fetched.rollout_path,
        Some(PathBuf::from("/tmp/rollout.jsonl"))
    );

    // Archiving flips the flag but keeps the thread readable.
    state.mark_archived(THREAD_ID).expect("archive thread");
    let after_archive = state
        .get_thread(THREAD_ID)
        .expect("read archived thread")
        .expect("thread exists after archive");
    assert!(after_archive.archived);

    // With archived threads included, the listing must not come back empty.
    let filters = ThreadListFilters {
        include_archived: true,
        limit: Some(10),
    };
    let listing = state.list_threads(filters).expect("list threads");
    assert!(!listing.is_empty());
}
/// Verifies that `StateStore::open` upgrades a legacy database in place.
///
/// The fixture is written through a raw connection. Its `threads` table has no
/// `current_leaf_id` column and its `messages` table has no `parent_entry_id`
/// column. It holds one thread and three messages with distinct timestamps.
/// After the store opens the file, the thread and messages must read back
/// unchanged, the thread must have a current leaf assigned, and opening the
/// same file again must succeed.
#[test]
fn init_schema_migration() {
    let path = temp_state_path("init_schema_migration");
    let conn = Connection::open(&path).expect("open state db");
    conn.execute_batch(
        r#"
CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY,
rollout_path TEXT,
preview TEXT NOT NULL,
ephemeral INTEGER NOT NULL,
model_provider TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
status TEXT NOT NULL,
path TEXT,
cwd TEXT NOT NULL,
cli_version TEXT NOT NULL,
source TEXT NOT NULL,
title TEXT,
sandbox_policy TEXT,
approval_mode TEXT,
archived INTEGER NOT NULL DEFAULT 0,
archived_at INTEGER,
git_sha TEXT,
git_branch TEXT,
git_origin_url TEXT,
memory_mode TEXT
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
thread_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
item_json TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
);
INSERT INTO threads (
id, preview, ephemeral, model_provider, created_at, updated_at, status, cwd, cli_version, source, archived
)
VALUES (
'thread-test-1', 'hello', false, 'deepseek', 0, 0, 'running', '/tmp/project', '0.0.0-test', 'interactive', false
);
INSERT INTO messages (thread_id, role, content, created_at) VALUES
('thread-test-1', 'foo0', 'bar0', 0),
('thread-test-1', 'foo1', 'bar1', 1),
('thread-test-1', 'foo2', 'bar2', 2);
"#,
    )
    .expect("init schema migration");
    // Close the seeding connection before the store takes over the file, as
    // the v1 migration test does, so the migration never runs next to a
    // lingering raw handle.
    drop(conn);

    let store = StateStore::open(Some(path.clone())).expect("open state store");
    let thread = store
        .get_thread("thread-test-1")
        .expect("read thread")
        .expect("thread survives migration");
    assert_eq!(thread.id, "thread-test-1");
    assert_eq!(thread.preview, "hello");
    assert!(!thread.ephemeral);
    assert_eq!(thread.model_provider, "deepseek");
    assert_eq!(thread.created_at, 0);
    assert_eq!(thread.updated_at, 0);
    assert_eq!(thread.status, ThreadStatus::Running);
    assert_eq!(thread.cwd, PathBuf::from("/tmp/project"));
    assert_eq!(thread.cli_version, "0.0.0-test");
    assert_eq!(thread.source, SessionSource::Interactive);
    // The legacy table had no leaf pointer; one must exist after opening.
    assert!(thread.current_leaf_id.is_some());

    let messages = store
        .list_messages("thread-test-1", None)
        .expect("list messages");
    assert_eq!(messages.len(), 3);
    // Count in i64 so the timestamp comparison needs no cast.
    for (i, message) in (0_i64..).zip(messages.iter()) {
        assert_eq!(message.thread_id, "thread-test-1");
        assert_eq!(message.role, format!("foo{i}"));
        assert_eq!(message.content, format!("bar{i}"));
        assert_eq!(message.created_at, i);
    }

    // Opening an already-migrated file must succeed as well.
    StateStore::open(Some(path)).expect("open state store - idempotent");
}
/// A database created from scratch by `StateStore::open` must already satisfy
/// `assert_workflow_trace_schema`.
#[test]
fn fresh_schema_includes_workflow_trace_tables() {
    let db_path = temp_state_path("fresh_schema_includes_workflow_trace_tables");

    // The store is only needed to create the file; release it immediately.
    drop(StateStore::open(Some(db_path.clone())).expect("open state store"));

    let raw = Connection::open(&db_path).expect("open state db");
    assert_workflow_trace_schema(&raw);
}
/// A database stamped `user_version = 1` must be upgraded on open.
///
/// The fixture creates the `threads`, `messages`, `checkpoints`, `jobs` and
/// `thread_dynamic_tools` tables plus one thread row, and none of the tables
/// checked by `assert_workflow_trace_schema`. After `StateStore::open`, the
/// thread must still be readable and the schema check must pass.
#[test]
fn v1_schema_migrates_workflow_trace_tables() {
    const V1_FIXTURE_SQL: &str = r#"
CREATE TABLE threads (
id TEXT PRIMARY KEY,
rollout_path TEXT,
preview TEXT NOT NULL,
ephemeral INTEGER NOT NULL,
model_provider TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
status TEXT NOT NULL,
path TEXT,
cwd TEXT NOT NULL,
cli_version TEXT NOT NULL,
source TEXT NOT NULL,
title TEXT,
sandbox_policy TEXT,
approval_mode TEXT,
archived INTEGER NOT NULL DEFAULT 0,
archived_at INTEGER,
git_sha TEXT,
git_branch TEXT,
git_origin_url TEXT,
memory_mode TEXT,
current_leaf_id INTEGER
);
CREATE TABLE messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
thread_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
item_json TEXT,
created_at INTEGER NOT NULL,
parent_entry_id INTEGER
);
CREATE TABLE checkpoints (
thread_id TEXT NOT NULL,
checkpoint_id TEXT NOT NULL,
state_json TEXT NOT NULL,
created_at INTEGER NOT NULL,
PRIMARY KEY(thread_id, checkpoint_id)
);
CREATE TABLE jobs (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
status TEXT NOT NULL,
progress INTEGER,
detail TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
CREATE TABLE thread_dynamic_tools (
thread_id TEXT NOT NULL,
position INTEGER NOT NULL,
name TEXT NOT NULL,
description TEXT,
input_schema TEXT NOT NULL,
PRIMARY KEY (thread_id, position)
);
INSERT INTO threads (
id, preview, ephemeral, model_provider, created_at, updated_at, status, cwd, cli_version, source, archived
)
VALUES (
'thread-test-1', 'hello', false, 'deepseek', 0, 0, 'running', '/tmp/project', '0.0.0-test', 'interactive', false
);
PRAGMA user_version = 1;
"#;

    let db_path = temp_state_path("v1_schema_migrates_workflow_trace_tables");

    // Seed the v1 layout; the raw handle is closed when this scope ends, before
    // the store touches the file.
    {
        let seed = Connection::open(&db_path).expect("open state db");
        seed.execute_batch(V1_FIXTURE_SQL)
            .expect("create v1 schema");
    }

    let migrated = StateStore::open(Some(db_path.clone())).expect("open state store");
    let survivor = migrated
        .get_thread("thread-test-1")
        .expect("read thread")
        .expect("thread survives migration");
    assert_eq!(survivor.preview, "hello");

    // Inspect the file through a fresh raw handle.
    let raw = Connection::open(&db_path).expect("open state db");
    assert_workflow_trace_schema(&raw);
}
/// Verifies legacy-message migration when every message shares one timestamp.
///
/// The fixture has no `parent_entry_id` column and four messages that all
/// carry `created_at = 123`. After `StateStore::open`, the messages must list
/// in insertion order, the first must have no parent, and each later one must
/// point at the message directly before it. Opening the file a second time
/// must succeed.
#[test]
fn init_schema_migration_same_second_messages() {
    let path = temp_state_path("init_schema_migration_same_second_messages");
    let conn = Connection::open(&path).expect("open state db");
    conn.execute_batch(
        r#"
CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY,
rollout_path TEXT,
preview TEXT NOT NULL,
ephemeral INTEGER NOT NULL,
model_provider TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
status TEXT NOT NULL,
path TEXT,
cwd TEXT NOT NULL,
cli_version TEXT NOT NULL,
source TEXT NOT NULL,
title TEXT,
sandbox_policy TEXT,
approval_mode TEXT,
archived INTEGER NOT NULL DEFAULT 0,
archived_at INTEGER,
git_sha TEXT,
git_branch TEXT,
git_origin_url TEXT,
memory_mode TEXT
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
thread_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
item_json TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
);
INSERT INTO threads (
id, preview, ephemeral, model_provider, created_at, updated_at, status, cwd, cli_version, source, archived
)
VALUES (
'thread-test-2', 'hello', false, 'deepseek', 0, 0, 'running', '/tmp/project', '0.0.0-test', 'interactive', false
);
INSERT INTO messages (thread_id, role, content, created_at) VALUES
('thread-test-2', 'foo0', 'bar0', 123),
('thread-test-2', 'foo1', 'bar1', 123),
('thread-test-2', 'foo2', 'bar2', 123),
('thread-test-2', 'foo3', 'bar3', 123);
"#,
    )
    .expect("init schema migration");
    // Close the seeding connection before the store takes over the file, as
    // the v1 migration test does.
    drop(conn);

    let store = StateStore::open(Some(path.clone())).expect("open state store");
    let messages = store
        .list_messages("thread-test-2", None)
        .expect("list messages");
    assert_eq!(messages.len(), 4);
    for (i, message) in messages.iter().enumerate() {
        assert_eq!(message.thread_id, "thread-test-2");
        assert_eq!(message.role, format!("foo{i}"));
        assert_eq!(message.content, format!("bar{i}"));
        assert_eq!(message.created_at, 123);
    }

    // Timestamps cannot order these rows, so the parent chain is the check:
    // the first message is the root and every later one hangs off its
    // immediate predecessor.
    assert_eq!(messages[0].parent_entry_id, None);
    for pair in messages.windows(2) {
        assert_eq!(pair[1].parent_entry_id, Some(pair[0].id));
    }

    // Test idempotent reopen after same-second parent links are migrated.
    StateStore::open(Some(path)).expect("open state store - idempotent");
}
/// Exercises forking, leaf switching and clearing on a single thread.
///
/// Expectations encoded by the assertions, in order:
/// 1. five appended messages list back as `foo0..foo4`;
/// 2. forking at the third message with `foo5` makes the listing
///    `foo0, foo1, foo2, foo5` and yields two leaves;
/// 3. pointing the current leaf back at the fifth message and appending `foo6`
///    makes the listing `foo0..foo4, foo6`, still with two leaves;
/// 4. clearing the thread's messages leaves no leaves and no current leaf.
#[test]
fn test_fork() {
    const THREAD_ID: &str = "thread-test-1";

    let db_path = temp_state_path("test_fork");
    let state = StateStore::open(Some(db_path.clone())).expect("open state store");
    let timestamp = chrono::Utc::now().timestamp();
    let metadata = ThreadMetadata {
        id: THREAD_ID.to_owned(),
        rollout_path: Some(PathBuf::from("/tmp/rollout.jsonl")),
        preview: String::from("hello"),
        ephemeral: false,
        model_provider: String::from("deepseek"),
        created_at: timestamp,
        updated_at: timestamp,
        status: ThreadStatus::Running,
        path: Some(PathBuf::from("/tmp/project")),
        cwd: PathBuf::from("/tmp/project"),
        cli_version: String::from("0.0.0-test"),
        source: SessionSource::Interactive,
        name: Some(String::from("Test Thread")),
        sandbox_policy: Some(String::from("workspace-write")),
        approval_mode: Some(String::from("on-request")),
        archived: false,
        archived_at: None,
        git_sha: None,
        git_branch: None,
        git_origin_url: None,
        memory_mode: Some(String::from("extended")),
        current_leaf_id: None,
    };
    state.upsert_thread(&metadata).expect("upsert thread");

    // Build the initial five-message line.
    for (role, content) in [
        ("foo0", "bar0"),
        ("foo1", "bar1"),
        ("foo2", "bar2"),
        ("foo3", "bar3"),
        ("foo4", "bar4"),
    ] {
        state
            .append_message(THREAD_ID, role, content, None)
            .expect("append message");
    }

    let trunk = state
        .list_messages(THREAD_ID, None)
        .expect("list messages");
    assert_eq!(trunk.len(), 5);
    // Remember each message id as a string for the fork/leaf calls below.
    let mut ids = Vec::with_capacity(trunk.len());
    for (i, message) in trunk.iter().enumerate() {
        assert_eq!(message.thread_id, "thread-test-1");
        assert_eq!(message.role, format!("foo{}", i));
        assert_eq!(message.content, format!("bar{}", i));
        ids.push(message.id.to_string());
    }

    // The same metadata (with `current_leaf_id: None`) is written again before
    // forking. NOTE(review): presumably this guards against an upsert
    // disturbing the stored leaf pointer — confirm against `upsert_thread`.
    state.upsert_thread(&metadata).expect("upsert thread");
    state
        .fork_at_message(&ids[2], "foo5", "bar5", None)
        .expect("fork at message");

    let forked = state
        .list_messages(THREAD_ID, None)
        .expect("list messages");
    assert_eq!(forked.len(), 4);
    for (message, i) in forked.iter().zip([0_i64, 1, 2, 5]) {
        assert_eq!(message.thread_id, "thread-test-1");
        assert_eq!(message.role, format!("foo{}", i));
        assert_eq!(message.content, format!("bar{}", i));
    }
    let leaves_after_fork = state
        .list_leaf_messages(THREAD_ID)
        .expect("list leaf messages");
    assert_eq!(leaves_after_fork.len(), 2);

    // Switch back to the end of the original line and extend it.
    state
        .set_current_leaf_id(THREAD_ID, &ids[4])
        .expect("set current leaf id");
    state
        .append_message(THREAD_ID, "foo6", "bar6", None)
        .expect("append message");

    let extended = state
        .list_messages(THREAD_ID, None)
        .expect("list messages");
    assert_eq!(extended.len(), 6);
    for (message, i) in extended.iter().zip([0_i64, 1, 2, 3, 4, 6]) {
        assert_eq!(message.thread_id, "thread-test-1");
        assert_eq!(message.role, format!("foo{}", i));
        assert_eq!(message.content, format!("bar{}", i));
    }
    let leaves_after_extend = state
        .list_leaf_messages(THREAD_ID)
        .expect("list leaf messages");
    assert_eq!(leaves_after_extend.len(), 2);

    // Clearing removes every leaf and resets the thread's leaf pointer.
    state.clear_messages(THREAD_ID).expect("clear messages");
    let leaves_after_clear = state
        .list_leaf_messages(THREAD_ID)
        .expect("list leaf messages");
    assert_eq!(leaves_after_clear.len(), 0);
    let cleared = state.get_thread(THREAD_ID).expect("get thread").unwrap();
    assert!(cleared.current_leaf_id.is_none());
}