diff --git a/crates/state/src/lib.rs b/crates/state/src/lib.rs index 7347245d..589ae423 100644 --- a/crates/state/src/lib.rs +++ b/crates/state/src/lib.rs @@ -168,6 +168,7 @@ impl StateStore { if user_version == 0 { conn.execute_batch( r#" + BEGIN; CREATE TABLE IF NOT EXISTS threads ( id TEXT PRIMARY KEY, rollout_path TEXT, @@ -244,7 +245,7 @@ impl StateStore { SELECT m2.id FROM messages m2 WHERE m2.created_at < messages.created_at AND m2.thread_id = messages.thread_id - ORDER BY m2.created_at DESC + ORDER BY m2.id DESC LIMIT 1 ); CREATE INDEX idx_messages_parent_entry_id ON messages(parent_entry_id); @@ -256,11 +257,12 @@ impl StateStore { SELECT m.id FROM messages m WHERE m.thread_id = threads.id - ORDER BY m.created_at DESC + ORDER BY m.id DESC LIMIT 1 ); PRAGMA user_version = 1; + COMMIT; "#, ) .context("failed to initialize thread schema")?; @@ -589,7 +591,7 @@ impl StateStore { tx.commit() .context("failed to commit append message transaction")?; - Ok(conn.last_insert_rowid()) + Ok(next_leaf_id) } pub fn list_messages( @@ -619,7 +621,7 @@ impl StateStore { WHERE a.depth < ?2 ) SELECT id, thread_id, role, content, item_json, created_at, parent_entry_id FROM ancestors - ORDER BY depth ASC + ORDER BY depth DESC "# ) .context("failed to prepare message listing query")?; @@ -668,7 +670,7 @@ impl StateStore { .transaction() .context("failed to begin fork message transaction")?; - let thread_id: Option = tx + let thread_id: String = tx .query_row( "SELECT thread_id FROM messages WHERE id = ?1", params![message_id], diff --git a/crates/state/tests/parity_state.rs b/crates/state/tests/parity_state.rs index cee9192a..96ae2b7d 100644 --- a/crates/state/tests/parity_state.rs +++ b/crates/state/tests/parity_state.rs @@ -148,9 +148,9 @@ fn init_schema_migration() { assert_eq!(messages.len(), 3); for (i, message) in messages.iter().enumerate() { assert_eq!(message.thread_id, "thread-test-1"); - assert_eq!(message.role, format!("foo{}", 2 - i)); - assert_eq!(message.content, format!("bar{}", 2 - i)); - assert_eq!(message.created_at, 2 - i as i64); + assert_eq!(message.role, format!("foo{}", i)); + assert_eq!(message.content, format!("bar{}", i)); + assert_eq!(message.created_at, i as i64); } // Test idempotent @@ -213,8 +213,8 @@ fn test_fork() { .enumerate() .map(|(i, message)| { assert_eq!(message.thread_id, "thread-test-1"); - assert_eq!(message.role, format!("foo{}", 4 - i)); - assert_eq!(message.content, format!("bar{}", 4 - i)); + assert_eq!(message.role, format!("foo{}", i)); + assert_eq!(message.content, format!("bar{}", i)); message.id.to_string() }) .collect::>(); @@ -226,7 +226,7 @@ fn test_fork() { .list_messages("thread-test-1", None) .expect("list messages"); assert_eq!(messages.len(), 4); - const LIST_1: [i64; 4] = [5, 2, 1, 0]; + const LIST_1: [i64; 4] = [0, 1, 2, 5]; messages .iter() .zip(LIST_1.iter()) @@ -241,7 +241,7 @@ fn test_fork() { assert_eq!(leaves.len(), 2); store - .set_current_leaf_id("thread-test-1", &ids[0]) + .set_current_leaf_id("thread-test-1", &ids[4]) .expect("set current leaf id"); store .append_message("thread-test-1", "foo6", "bar6", None) @@ -250,7 +250,7 @@ fn test_fork() { .list_messages("thread-test-1", None) .expect("list messages"); assert_eq!(messages.len(), 6); - const LIST_2: [i64; 6] = [6, 4, 3, 2, 1, 0]; + const LIST_2: [i64; 6] = [0, 1, 2, 3, 4, 6]; messages .iter() .zip(LIST_2.iter())