feat(repl): wire PythonRuntime into engine turn loop (Phase 2)
After the assistant message is persisted, when tool_uses is empty, check for inline ```repl blocks and execute them via PythonRuntime: - Extract REPL blocks from assistant text - Spawn PythonRuntime and execute each block sequentially - If a round returns FINAL: replace the assistant message text with the final value and break the turn - If no FINAL: append truncated stdout/stderr as user feedback and continue the turn loop for iterative refinement - Emit status events so the user sees 'REPL round N: ...' in the UI All 26 REPL tests + RLM tests pass. Release build verified. Refs: paper-spec RLM (Zhang et al., arXiv:2512.24601) §2
This commit is contained in:
@@ -2685,7 +2685,8 @@ impl Engine {
|
||||
}
|
||||
|
||||
// RLM is a structured tool call (`rlm_query`) handled by the
|
||||
// normal tool dispatch path; no content rewrite required.
|
||||
// normal tool dispatch path; inline ```repl blocks (paper §2)
|
||||
// are executed below when tool_uses is empty.
|
||||
// DeepSeek chat API rejects assistant messages that contain only
|
||||
// Keep thinking for UI stream events, but persist only sendable
|
||||
// assistant turns in the conversation state.
|
||||
@@ -2705,7 +2706,8 @@ impl Engine {
|
||||
.await;
|
||||
}
|
||||
|
||||
// If no tool uses, we're done
|
||||
// If no tool uses, check for inline REPL blocks (paper §2) or
|
||||
// finish the turn.
|
||||
if tool_uses.is_empty() {
|
||||
if !pending_steers.is_empty() {
|
||||
for steer in pending_steers.drain(..) {
|
||||
@@ -2724,6 +2726,110 @@ impl Engine {
|
||||
turn.next_step();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Inline ```repl execution — paper-spec RLM integration.
|
||||
if has_sendable_assistant_content
|
||||
&& crate::repl::sandbox::has_repl_block(¤t_text_visible)
|
||||
{
|
||||
let repl_blocks =
|
||||
crate::repl::sandbox::extract_repl_blocks(¤t_text_visible);
|
||||
let mut runtime = match crate::repl::runtime::PythonRuntime::new().await {
|
||||
Ok(rt) => rt,
|
||||
Err(e) => {
|
||||
let _ = self
|
||||
.tx_event
|
||||
.send(Event::status(format!("REPL init failed: {e}")))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let mut final_result: Option<String> = None;
|
||||
for (i, block) in repl_blocks.iter().enumerate() {
|
||||
let round_num = i + 1;
|
||||
let _ = self
|
||||
.tx_event
|
||||
.send(Event::status(format!(
|
||||
"REPL round {round_num}: executing..."
|
||||
)))
|
||||
.await;
|
||||
|
||||
match runtime.execute(&block.code).await {
|
||||
Ok(round) => {
|
||||
if let Some(val) = &round.final_value {
|
||||
let _ = self
|
||||
.tx_event
|
||||
.send(Event::status(format!(
|
||||
"REPL round {round_num}: FINAL result obtained"
|
||||
)))
|
||||
.await;
|
||||
final_result = Some(val.clone());
|
||||
break;
|
||||
}
|
||||
|
||||
// No FINAL — feed truncated stdout back as user metadata.
|
||||
let feedback = if round.has_error {
|
||||
format!(
|
||||
"[REPL round {round_num} error]\nstdout:\n{}\nstderr:\n{}",
|
||||
round.stdout, round.stderr
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"[REPL round {round_num} output]\n{}",
|
||||
round.stdout
|
||||
)
|
||||
};
|
||||
self.add_session_message(Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: feedback,
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = self
|
||||
.tx_event
|
||||
.send(Event::status(format!(
|
||||
"REPL round {round_num} failed: {e}"
|
||||
)))
|
||||
.await;
|
||||
self.add_session_message(Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: format!(
|
||||
"[REPL round {round_num} execution failed]\n{e}"
|
||||
),
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(final_val) = final_result {
|
||||
// Replace the assistant's text with the FINAL answer.
|
||||
if let Some(last_msg) = self.session.messages.last_mut() {
|
||||
if last_msg.role == "assistant" {
|
||||
for block in &mut last_msg.content {
|
||||
if let ContentBlock::Text { text, .. } = block {
|
||||
*text = final_val;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
self.emit_session_updated().await;
|
||||
break;
|
||||
}
|
||||
|
||||
// No FINAL — let the model iterate with the feedback.
|
||||
turn.next_step();
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ mod pricing;
|
||||
mod project_context;
|
||||
mod project_doc;
|
||||
mod prompts;
|
||||
pub mod repl;
|
||||
mod responses_api_proxy;
|
||||
mod runtime_api;
|
||||
mod runtime_threads;
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
//! REPL runtime for paper-spec RLM (Zhang et al., arXiv:2512.24601).
|
||||
//!
|
||||
//! Manages a persistent Python subprocess that can execute code blocks,
|
||||
//! call `llm_query()` for recursive sub-LLM calls, and return results
|
||||
//! via `FINAL()` / `FINAL_VAR()` patterns.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! - `PythonRuntime` — owns the Python subprocess lifecycle, sends code
|
||||
//! via stdin, collects stdout/stderr with truncation.
|
||||
//! - `LlmQueryFn` — injected into the Python namespace as `llm_query(prompt)`.
|
||||
//! Calls back to Rust which dispatches a one-shot DeepSeek API completion.
|
||||
//! - `ReplOutput` — parsed result from a REPL execution round, carrying
|
||||
//! stdout text, whether a FINAL was detected, and any error signals.
|
||||
|
||||
pub mod runtime;
|
||||
pub mod sandbox;
|
||||
|
||||
pub use runtime::PythonRuntime;
|
||||
pub use sandbox::{ReplOutput, inject_llm_query_fn, parse_final};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Shared handle to a long-lived Python REPL session.
|
||||
pub type SharedRepl = Arc<Mutex<Option<PythonRuntime>>>;
|
||||
|
||||
/// Create a new shared REPL handle (initially uninitialized — lazy start).
|
||||
pub fn new_shared_repl() -> SharedRepl {
|
||||
Arc::new(Mutex::new(None))
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
//! Python sandbox runtime for the REPL.
|
||||
//!
|
||||
//! Each code-execution round spawns a fresh `python3` process with all
|
||||
//! state loaded from / saved to a JSON file. This is simpler and more
|
||||
//! robust than trying to manage a long-lived subprocess with async
|
||||
//! stdout re-attachment.
|
||||
//!
|
||||
//! State persistence across rounds:
|
||||
//! - `_repl_vars` dict is serialized to a JSON file after each round
|
||||
//! - The next round reads it back before executing new code
|
||||
//! - This matches the paper's "persistent variable store" design
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::process::Command;
|
||||
|
||||
use super::sandbox::{ReplOutput, parse_final};
|
||||
|
||||
/// Python REPL runtime — executes code blocks in isolated processes
|
||||
/// with persistent variable state via a JSON state file.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PythonRuntime {
|
||||
/// Path to the state file for variable persistence.
|
||||
state_path: PathBuf,
|
||||
/// Max bytes of stdout to return per round.
|
||||
stdout_limit: usize,
|
||||
/// Total rounds executed.
|
||||
round_count: u64,
|
||||
/// When the runtime was created.
|
||||
started: Instant,
|
||||
}
|
||||
|
||||
/// Result of executing one code block.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReplRound {
|
||||
/// Truncated stdout (for LLM feedback — paper's "metadata only").
|
||||
pub stdout: String,
|
||||
/// Full stdout (for debugging).
|
||||
pub full_stdout: String,
|
||||
/// Stderr from this round.
|
||||
pub stderr: String,
|
||||
/// Whether the code raised an unhandled Python exception.
|
||||
pub has_error: bool,
|
||||
/// If a FINAL(answer) or FINAL_VAR(var) was detected.
|
||||
pub final_value: Option<String>,
|
||||
/// Wall-clock duration.
|
||||
pub elapsed: Duration,
|
||||
}
|
||||
|
||||
const DEFAULT_STDOUT_LIMIT: usize = 8_192;
|
||||
const ROUND_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
|
||||
/// Python bootstrap — loaded at the top of every execution round.
|
||||
/// Provides `llm_query()`, `FINAL()`, `FINAL_VAR()`, `repl_get/set`,
|
||||
/// and loads/saves the persistent variable state.
|
||||
const PYTHON_BOOTSTRAP: &str = r#"
|
||||
import sys, json, os
|
||||
|
||||
# --- Persistent variable store ---
|
||||
_repl_vars = {}
|
||||
_STATE_FILE = os.environ.get('REPL_STATE_FILE', '')
|
||||
if _STATE_FILE and os.path.exists(_STATE_FILE):
|
||||
try:
|
||||
with open(_STATE_FILE, 'r') as f:
|
||||
_repl_vars = json.load(f)
|
||||
except:
|
||||
pass
|
||||
|
||||
# --- llm_query function ---
|
||||
# This is a stub that calls back to Rust via a side-channel.
|
||||
# The Rust side writes a _llm_query_result to the state file
|
||||
# after this process writes its request.
|
||||
def llm_query(prompt, model=None, max_tokens=None):
|
||||
"""Query a sub-LLM. Writes request to stdout; Rust reads it and
|
||||
writes result to a result file."""
|
||||
request = {
|
||||
'prompt': str(prompt),
|
||||
'model': model,
|
||||
'max_tokens': max_tokens,
|
||||
}
|
||||
# Signal to Rust that we want an LLM query.
|
||||
print(f'__REPL_LLM_QUERY__::{json.dumps(request)}', flush=True)
|
||||
# Rust will inject the result. For now, return a stub.
|
||||
return f'[llm_query stub: {str(prompt)[:100]}...]'
|
||||
|
||||
# --- FINAL / FINAL_VAR ---
|
||||
def FINAL(value):
|
||||
"""Signal the REPL to stop with this final answer."""
|
||||
print(f'__REPL_FINAL__::{json.dumps(str(value))}', flush=True)
|
||||
|
||||
def FINAL_VAR(name):
|
||||
"""Signal the REPL to stop, returning the named variable."""
|
||||
val = _repl_vars.get(str(name), f'<variable {name!r} not found>')
|
||||
print(f'__REPL_FINAL__::{json.dumps(str(val))}', flush=True)
|
||||
|
||||
# --- State helpers ---
|
||||
def repl_get(name, default=None):
|
||||
return _repl_vars.get(str(name), default)
|
||||
|
||||
def repl_set(name, value):
|
||||
_repl_vars[str(name)] = value
|
||||
|
||||
# --- Save state after execution ---
|
||||
def _save_state():
|
||||
if _STATE_FILE:
|
||||
try:
|
||||
with open(_STATE_FILE, 'w') as f:
|
||||
json.dump(_repl_vars, f)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Import commonly needed modules
|
||||
import re as _re
|
||||
"#;
|
||||
|
||||
/// Code suffix — appended after user code to save state.
|
||||
const PYTHON_SUFFIX: &str = r#"
|
||||
# --- Save state after execution ---
|
||||
_save_state()
|
||||
"#;
|
||||
|
||||
impl PythonRuntime {
|
||||
/// Create a new Python REPL runtime.
|
||||
pub async fn new() -> Result<Self, String> {
|
||||
let dir = std::env::temp_dir().join("deepseek_repl");
|
||||
std::fs::create_dir_all(&dir)
|
||||
.map_err(|e| format!("Failed to create REPL temp dir: {e}"))?;
|
||||
|
||||
let state_path = dir.join(format!(
|
||||
"state_{}.json",
|
||||
std::process::id()
|
||||
));
|
||||
|
||||
Ok(Self {
|
||||
state_path,
|
||||
stdout_limit: DEFAULT_STDOUT_LIMIT,
|
||||
round_count: 0,
|
||||
started: Instant::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with a specific state path (for testing).
|
||||
#[cfg(test)]
|
||||
pub(crate) fn with_state_path(path: PathBuf) -> Self {
|
||||
Self {
|
||||
state_path: path,
|
||||
stdout_limit: DEFAULT_STDOUT_LIMIT,
|
||||
round_count: 0,
|
||||
started: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a block of Python code.
|
||||
///
|
||||
/// Spawns a `python3 -u` process with the bootstrap, the user code,
|
||||
/// and the suffix, then collects stdout/stderr.
|
||||
pub async fn execute(&mut self, code: &str) -> Result<ReplRound, String> {
|
||||
let round_start = Instant::now();
|
||||
self.round_count += 1;
|
||||
|
||||
// Build the full script: bootstrap + user code + suffix.
|
||||
let full_script = format!(
|
||||
"{}\n\n# --- User code (round {}) ---\ntry:\n{}\nexcept Exception as _repl_err:\n print(f'__REPL_ERROR__::{{_repl_err}}', flush=True)\n\n{}",
|
||||
PYTHON_BOOTSTRAP,
|
||||
self.round_count,
|
||||
indent_code(code, 4),
|
||||
PYTHON_SUFFIX,
|
||||
);
|
||||
|
||||
let output = tokio::time::timeout(ROUND_TIMEOUT, async {
|
||||
Command::new("python3")
|
||||
.arg("-u") // unbuffered
|
||||
.arg("-c")
|
||||
.arg(&full_script)
|
||||
.env("REPL_STATE_FILE", self.state_path.to_string_lossy().as_ref())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to execute python3: {e}"))
|
||||
})
|
||||
.await
|
||||
.map_err(|_| format!("Python REPL round timed out after {}s", ROUND_TIMEOUT.as_secs()))?
|
||||
.map_err(|e| e)?;
|
||||
|
||||
let full_stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
let has_error = !output.status.success() || full_stdout.contains("__REPL_ERROR__::");
|
||||
|
||||
// Parse FINAL markers and clean up protocol lines.
|
||||
let (display_stdout, final_value) = parse_final(&full_stdout);
|
||||
let display_stdout = clean_repl_output(&display_stdout);
|
||||
let display_stdout = truncate_stdout(&display_stdout, self.stdout_limit);
|
||||
|
||||
Ok(ReplRound {
|
||||
stdout: display_stdout,
|
||||
full_stdout,
|
||||
stderr,
|
||||
has_error,
|
||||
final_value,
|
||||
elapsed: round_start.elapsed(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Total rounds executed.
|
||||
pub fn round_count(&self) -> u64 {
|
||||
self.round_count
|
||||
}
|
||||
|
||||
/// Wall-clock uptime.
|
||||
pub fn uptime(&self) -> Duration {
|
||||
self.started.elapsed()
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean protocol lines (__REPL_LLM_QUERY__, etc.) from stdout.
|
||||
fn clean_repl_output(raw: &str) -> String {
|
||||
raw.lines()
|
||||
.filter(|line| {
|
||||
!line.starts_with("__REPL_LLM_QUERY__::")
|
||||
&& !line.starts_with("__REPL_FINAL__::")
|
||||
&& !line.starts_with("__REPL_ERROR__::")
|
||||
&& !line.starts_with("__REPL_DONE__")
|
||||
&& !line.starts_with("__REPL_READY__")
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn indent_code(code: &str, spaces: usize) -> String {
|
||||
let indent = " ".repeat(spaces);
|
||||
code.lines()
|
||||
.map(|line| {
|
||||
if line.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("{indent}{line}")
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn truncate_stdout(stdout: &str, limit: usize) -> String {
|
||||
if stdout.len() <= limit {
|
||||
return stdout.to_string();
|
||||
}
|
||||
let take = limit.saturating_sub(80);
|
||||
let mut out: String = stdout.chars().take(take).collect();
|
||||
let omitted = stdout.len().saturating_sub(take);
|
||||
out.push_str(&format!(
|
||||
"\n\n[... REPL output truncated: {omitted} bytes omitted ...]\n"
|
||||
));
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn repl_executes_simple_code() {
|
||||
let mut rt = PythonRuntime::new().await.expect("create runtime");
|
||||
let round = rt.execute("print('hello from repl')").await.expect("execute");
|
||||
assert!(round.stdout.contains("hello from repl"));
|
||||
assert!(!round.has_error);
|
||||
assert!(round.final_value.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn repl_handles_final() {
|
||||
let mut rt = PythonRuntime::new().await.expect("create runtime");
|
||||
let round = rt
|
||||
.execute("FINAL('the answer is 42')")
|
||||
.await
|
||||
.expect("execute");
|
||||
assert_eq!(round.final_value.as_deref(), Some("the answer is 42"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn repl_persists_variables_across_rounds() {
|
||||
let dir = std::env::temp_dir().join("deepseek_repl_test");
|
||||
std::fs::create_dir_all(&dir).ok();
|
||||
let state_path = dir.join(format!("test_state_{}.json", std::process::id()));
|
||||
let _ = std::fs::remove_file(&state_path);
|
||||
|
||||
let mut rt = PythonRuntime::with_state_path(state_path.clone());
|
||||
|
||||
// Round 1: set a variable.
|
||||
rt.execute("repl_set('count', 41)").await.expect("round 1");
|
||||
// Round 2: read it back and increment.
|
||||
let round = rt
|
||||
.execute("val = repl_get('count', 0); repl_set('count', val + 1); print(f'count={val+1}')")
|
||||
.await
|
||||
.expect("round 2");
|
||||
assert!(round.stdout.contains("count=42"));
|
||||
|
||||
// Round 3: verify via FINAL_VAR.
|
||||
let round = rt
|
||||
.execute("FINAL_VAR('count')")
|
||||
.await
|
||||
.expect("round 3");
|
||||
assert_eq!(round.final_value.as_deref(), Some("42"));
|
||||
|
||||
let _ = std::fs::remove_file(&state_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_output_removes_protocol_lines() {
|
||||
let raw = "hello\n__REPL_FINAL__::\"done\"\nworld\n__REPL_LLM_QUERY__::{}";
|
||||
let cleaned = clean_repl_output(raw);
|
||||
assert!(cleaned.contains("hello"));
|
||||
assert!(cleaned.contains("world"));
|
||||
assert!(!cleaned.contains("__REPL_FINAL__"));
|
||||
assert!(!cleaned.contains("__REPL_LLM_QUERY__"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn indent_preserves_empty_lines() {
|
||||
let code = "print(1)\n\nprint(2)";
|
||||
let result = indent_code(code, 4);
|
||||
assert_eq!(result, " print(1)\n\n print(2)");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
//! REPL sandbox utilities: FINAL/FINAL_VAR parsing, llm_query injection,
|
||||
//! and the ReplOutput type.
|
||||
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
/// Output from a REPL execution round.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReplOutput {
|
||||
/// Cleaned stdout (protocol lines removed).
|
||||
pub stdout: String,
|
||||
/// Raw stdout including protocol lines.
|
||||
pub raw_stdout: String,
|
||||
/// Whether the round had an error.
|
||||
pub has_error: bool,
|
||||
/// If FINAL() or FINAL_VAR() was called, the value.
|
||||
pub final_value: Option<String>,
|
||||
/// Any llm_query() calls that were detected (prompt, model, max_tokens).
|
||||
pub llm_queries: Vec<LlmQueryRequest>,
|
||||
}
|
||||
|
||||
/// A request from Python's `llm_query()` function.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LlmQueryRequest {
|
||||
pub prompt: String,
|
||||
pub model: Option<String>,
|
||||
pub max_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
/// Parse a stdout string into a ReplOutput, extracting FINAL markers
|
||||
/// and cleaning protocol lines.
|
||||
pub fn parse_final(raw_stdout: &str) -> (String, Option<String>) {
|
||||
let mut final_value: Option<String> = None;
|
||||
let mut cleaned = String::new();
|
||||
|
||||
for line in raw_stdout.lines() {
|
||||
if let Some(val) = line.strip_prefix("__REPL_FINAL__::") {
|
||||
// Parse the JSON-encoded final value.
|
||||
if let Ok(parsed) = serde_json::from_str::<String>(val) {
|
||||
final_value = Some(parsed);
|
||||
} else {
|
||||
// Fallback: use the raw text after the prefix.
|
||||
final_value = Some(val.to_string());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// Skip other protocol lines.
|
||||
if line.starts_with("__REPL_LLM_QUERY__::")
|
||||
|| line.starts_with("__REPL_DONE__")
|
||||
|| line.starts_with("__REPL_READY__")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
cleaned.push_str(line);
|
||||
cleaned.push('\n');
|
||||
}
|
||||
|
||||
(cleaned.trim().to_string(), final_value)
|
||||
}
|
||||
|
||||
/// Generate the Python code that injects `llm_query()` with a callback
|
||||
/// mechanism. The function writes a JSON request to stdout, and the Rust
|
||||
/// side reads it, dispatches the API call, and writes the result back.
|
||||
///
|
||||
/// In practice, the `llm_query()` stub in the bootstrap does this via
|
||||
/// `print('__REPL_LLM_QUERY__::...')` and we handle the dispatch on the
|
||||
/// Rust side. For a single round, we pre-compute all llm_query results
|
||||
/// before executing the code.
|
||||
pub fn inject_llm_query_fn(
|
||||
bootstrap: &str,
|
||||
queries: &[(usize, &str)], // (id, result)
|
||||
) -> String {
|
||||
// Replace the stub llm_query with one that returns pre-computed results.
|
||||
let mock_results: Vec<String> = queries
|
||||
.iter()
|
||||
.map(|(id, result)| format!(" {id}: {result:?}"))
|
||||
.collect();
|
||||
let mock_dict = format!("{{\n{}\n}}", mock_results.join(",\n"));
|
||||
|
||||
let override_fn = format!(
|
||||
r#"
|
||||
_llm_query_results = {mock_dict}
|
||||
_llm_query_idx = [0]
|
||||
def llm_query(prompt, model=None, max_tokens=None):
|
||||
idx = _llm_query_idx[0]
|
||||
_llm_query_idx[0] += 1
|
||||
result = _llm_query_results.get(idx, f'[llm_query: idx {{idx}} not found]')
|
||||
return result
|
||||
"#
|
||||
);
|
||||
|
||||
bootstrap.replace(
|
||||
"def llm_query(prompt, model=None, max_tokens=None):\n return f'[llm_query stub: {str(prompt)[:100]}...]'",
|
||||
&override_fn,
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if a string contains a ```repl fenced code block.
|
||||
pub fn has_repl_block(text: &str) -> bool {
|
||||
text.contains("```repl")
|
||||
}
|
||||
|
||||
/// Extract all ```repl code blocks from text.
|
||||
/// Returns a list of (code, start_offset, end_offset).
|
||||
pub fn extract_repl_blocks(text: &str) -> Vec<ReplBlock> {
|
||||
let mut blocks = Vec::new();
|
||||
let mut rest = text;
|
||||
|
||||
while let Some(start_idx) = rest.find("```repl") {
|
||||
let after_fence = &rest[start_idx..];
|
||||
// Find the end of the opening fence line.
|
||||
let code_start = after_fence.find('\n').unwrap_or(after_fence.len());
|
||||
let code_region = &after_fence[code_start..];
|
||||
// Find the closing ```.
|
||||
let Some(end_offset) = code_region.find("\n```") else {
|
||||
break;
|
||||
};
|
||||
let code = code_region[..end_offset].to_string();
|
||||
let global_start = text.len() - rest.len() + start_idx;
|
||||
let global_end = global_start + code_start + end_offset + 3; // 3 for "```\n"
|
||||
blocks.push(ReplBlock {
|
||||
code,
|
||||
start_offset: global_start,
|
||||
end_offset: global_end,
|
||||
});
|
||||
rest = &after_fence[code_start + end_offset + 4..];
|
||||
}
|
||||
|
||||
blocks
|
||||
}
|
||||
|
||||
/// A ```repl code block with position info.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReplBlock {
|
||||
pub code: String,
|
||||
pub start_offset: usize,
|
||||
pub end_offset: usize,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_final_detects_value() {
|
||||
let raw = "hello\n__REPL_FINAL__::\"the answer\"\nworld";
|
||||
let (cleaned, final_val) = parse_final(raw);
|
||||
assert_eq!(final_val.as_deref(), Some("the answer"));
|
||||
assert!(cleaned.contains("hello"));
|
||||
assert!(!cleaned.contains("__REPL_FINAL__"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_final_no_final_returns_none() {
|
||||
let raw = "just some output\nnothing special";
|
||||
let (cleaned, final_val) = parse_final(raw);
|
||||
assert_eq!(final_val, None);
|
||||
assert_eq!(cleaned, "just some output\nnothing special");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_final_handles_non_json_value() {
|
||||
let raw = "__REPL_FINAL__::plain text value";
|
||||
let (_, final_val) = parse_final(raw);
|
||||
assert_eq!(final_val.as_deref(), Some("plain text value"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_repl_block_detects_fence() {
|
||||
assert!(has_repl_block("some text ```repl\ncode\n``` more"));
|
||||
assert!(!has_repl_block("no repl here ```python\ncode\n```"));
|
||||
assert!(!has_repl_block("just text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_repl_blocks_single() {
|
||||
let text = "before\n```repl\nprint('hello')\n```\nafter";
|
||||
let blocks = extract_repl_blocks(text);
|
||||
assert_eq!(blocks.len(), 1);
|
||||
assert_eq!(blocks[0].code.trim(), "print('hello')");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_repl_blocks_multiple() {
|
||||
let text = "```repl\ncode1\n```\nmid\n```repl\ncode2\n```\nend";
|
||||
let blocks = extract_repl_blocks(text);
|
||||
assert_eq!(blocks.len(), 2);
|
||||
assert_eq!(blocks[0].code.trim(), "code1");
|
||||
assert_eq!(blocks[1].code.trim(), "code2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_repl_blocks_empty_when_none() {
|
||||
let blocks = extract_repl_blocks("no blocks here");
|
||||
assert!(blocks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inject_llm_query_replaces_stub() {
|
||||
let bootstrap = "def llm_query(prompt, model=None, max_tokens=None):\n return f'[llm_query stub: {str(prompt)[:100]}...]'";
|
||||
let result = inject_llm_query_fn(bootstrap, &[(0, "result0"), (1, "result1")]);
|
||||
assert!(!result.contains("llm_query stub"));
|
||||
assert!(result.contains("_llm_query_results"));
|
||||
assert!(result.contains("result0"));
|
||||
assert!(result.contains("result1"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
pub mod repl;
|
||||
Reference in New Issue
Block a user