fix: harden Xiaomi MiMo speech flow
This commit is contained in:
+73
-14
@@ -476,6 +476,31 @@ fn parse_speech_audio_response(payload: &Value) -> Result<(Vec<u8>, Option<Strin
|
||||
Ok((audio_bytes, transcript))
|
||||
}
|
||||
|
||||
fn build_speech_synthesis_body(
|
||||
model: &str,
|
||||
text: &str,
|
||||
instruction: Option<&str>,
|
||||
audio: Value,
|
||||
) -> Value {
|
||||
let mut messages = Vec::new();
|
||||
if let Some(instruction) = instruction.map(str::trim).filter(|value| !value.is_empty()) {
|
||||
messages.push(json!({
|
||||
"role": "user",
|
||||
"content": instruction,
|
||||
}));
|
||||
}
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
}));
|
||||
|
||||
json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"audio": audio,
|
||||
})
|
||||
}
|
||||
|
||||
// === DeepSeekClient ===
|
||||
|
||||
/// Returns true when DEEPSEEK_FORCE_HTTP1 is set to a truthy value
|
||||
@@ -773,20 +798,7 @@ impl DeepSeekClient {
|
||||
audio["voice"] = json!(voice);
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": instruction.unwrap_or(""),
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
}
|
||||
],
|
||||
"audio": audio,
|
||||
});
|
||||
let body = build_speech_synthesis_body(&model, &text, instruction, audio);
|
||||
|
||||
let url = api_url(&self.base_url, "chat/completions");
|
||||
let response = self
|
||||
@@ -1366,6 +1378,53 @@ mod tests {
|
||||
assert_eq!(transcript, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn speech_synthesis_body_omits_user_message_without_instruction() {
|
||||
let body =
|
||||
build_speech_synthesis_body("mimo-v2.5-tts", "hello", None, json!({"format": "wav"}));
|
||||
let messages = body["messages"].as_array().expect("messages array");
|
||||
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0]["role"], "assistant");
|
||||
assert_eq!(messages[0]["content"], "hello");
|
||||
assert!(
|
||||
messages
|
||||
.iter()
|
||||
.all(|message| message["content"].as_str() != Some(""))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn speech_synthesis_body_ignores_blank_instruction() {
|
||||
let body = build_speech_synthesis_body(
|
||||
"mimo-v2.5-tts",
|
||||
"hello",
|
||||
Some(" \t\n "),
|
||||
json!({"format": "wav"}),
|
||||
);
|
||||
let messages = body["messages"].as_array().expect("messages array");
|
||||
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0]["role"], "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn speech_synthesis_body_includes_non_empty_instruction_first() {
|
||||
let body = build_speech_synthesis_body(
|
||||
"mimo-v2.5-tts-voicedesign",
|
||||
"hello",
|
||||
Some("warm and calm"),
|
||||
json!({"format": "wav"}),
|
||||
);
|
||||
let messages = body["messages"].as_array().expect("messages array");
|
||||
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0]["role"], "user");
|
||||
assert_eq!(messages[0]["content"], "warm and calm");
|
||||
assert_eq!(messages[1]["role"], "assistant");
|
||||
assert_eq!(messages[1]["content"], "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_name_roundtrip_dot() {
|
||||
let original = "multi_tool_use.parallel";
|
||||
|
||||
@@ -992,6 +992,7 @@ impl Engine {
|
||||
)
|
||||
.with_max_spawn_depth(self.config.max_spawn_depth)
|
||||
.with_step_api_timeout(self.config.subagent_api_timeout)
|
||||
.with_speech_output_dir(self.config.speech_output_dir.clone())
|
||||
.with_mcp_pool(mcp_pool)
|
||||
.background_runtime();
|
||||
let route = resolve_subagent_assignment_route(
|
||||
@@ -1496,6 +1497,7 @@ impl Engine {
|
||||
)
|
||||
.with_max_spawn_depth(self.config.max_spawn_depth)
|
||||
.with_step_api_timeout(self.config.subagent_api_timeout)
|
||||
.with_speech_output_dir(self.config.speech_output_dir.clone())
|
||||
.with_mcp_pool(mcp_pool.clone())
|
||||
.with_parent_completion_tx(self.tx_subagent_completion.clone());
|
||||
if let Some(context) = fork_context_for_runtime.clone() {
|
||||
|
||||
+61
-117
@@ -6,7 +6,6 @@ use std::process::{Command, Stdio};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use clap_complete::{Shell, generate};
|
||||
use dotenvy::dotenv;
|
||||
@@ -3568,7 +3567,12 @@ async fn run_models(config: &Config, args: ModelsArgs) -> Result<()> {
|
||||
|
||||
async fn run_speech(config: &Config, args: SpeechArgs) -> Result<()> {
|
||||
use crate::client::{DeepSeekClient, SpeechSynthesisRequest};
|
||||
use crate::config::{ApiProvider, normalize_model_name_for_provider};
|
||||
use crate::config::ApiProvider;
|
||||
use crate::tools::speech::{
|
||||
DEFAULT_VOICE, SPEECH_MODEL_EXAMPLES, combine_speech_instructions,
|
||||
default_speech_output_name, describe_speech_voice, encode_voice_clone_sample_data_uri,
|
||||
infer_speech_model, normalize_speech_format,
|
||||
};
|
||||
|
||||
let SpeechArgs {
|
||||
text,
|
||||
@@ -3600,24 +3604,16 @@ async fn run_speech(config: &Config, args: SpeechArgs) -> Result<()> {
|
||||
if clone_voice.is_some() && voice.is_some() {
|
||||
bail!("Use either --clone-voice or --voice for cloned voice data, not both");
|
||||
}
|
||||
let model = match model {
|
||||
Some(value) => {
|
||||
normalize_model_name_for_provider(ApiProvider::XiaomiMimo, &value).unwrap_or(value)
|
||||
}
|
||||
None => {
|
||||
if clone_voice.is_some() || voice_is_data_uri {
|
||||
"mimo-v2.5-tts-voiceclone".to_string()
|
||||
} else if voice_prompt.is_some() {
|
||||
"mimo-v2.5-tts-voicedesign".to_string()
|
||||
} else {
|
||||
"mimo-v2.5-tts".to_string()
|
||||
}
|
||||
}
|
||||
};
|
||||
let model = infer_speech_model(
|
||||
model.as_deref(),
|
||||
clone_voice.is_some() || voice_is_data_uri,
|
||||
voice_prompt.is_some(),
|
||||
);
|
||||
let model_lower = model.to_ascii_lowercase();
|
||||
if !model_lower.contains("tts") {
|
||||
bail!(
|
||||
"speech requires a TTS model (examples: mimo-v2.5-tts, mimo-v2.5-tts-voicedesign, mimo-v2.5-tts-voiceclone); got {model}"
|
||||
"speech requires a TTS model (examples: {}); got {model}",
|
||||
SPEECH_MODEL_EXAMPLES.join(", ")
|
||||
);
|
||||
}
|
||||
let is_voice_design = model_lower.contains("voicedesign");
|
||||
@@ -3635,7 +3631,7 @@ async fn run_speech(config: &Config, args: SpeechArgs) -> Result<()> {
|
||||
}
|
||||
|
||||
let voice = if let Some(clone_path) = clone_voice {
|
||||
Some(encode_voice_clone_data_uri(&clone_path)?)
|
||||
Some(encode_voice_clone_sample_data_uri(&clone_path)?)
|
||||
} else if is_voice_design {
|
||||
None
|
||||
} else if let Some(value) = voice.filter(|value| !value.trim().is_empty()) {
|
||||
@@ -3643,16 +3639,17 @@ async fn run_speech(config: &Config, args: SpeechArgs) -> Result<()> {
|
||||
} else if is_voice_clone {
|
||||
bail!("mimo-v2.5-tts-voiceclone requires --clone-voice <mp3|wav> or --voice <data-uri>");
|
||||
} else {
|
||||
Some("mimo_default".to_string())
|
||||
Some(DEFAULT_VOICE.to_string())
|
||||
};
|
||||
let format = normalize_speech_format(&format).with_context(|| {
|
||||
format!("Unsupported speech format '{format}' (allowed: wav, mp3, pcm16)")
|
||||
})?;
|
||||
let output = resolve_speech_output_path(
|
||||
output,
|
||||
output_dir.or_else(|| config.speech_output_dir()),
|
||||
&format,
|
||||
);
|
||||
let output = output.unwrap_or_else(|| {
|
||||
output_dir
|
||||
.or_else(|| config.speech_output_dir())
|
||||
.unwrap_or_default()
|
||||
.join(default_speech_output_name(&format))
|
||||
});
|
||||
|
||||
let client = DeepSeekClient::new(config)?;
|
||||
let response = client
|
||||
@@ -3699,99 +3696,12 @@ async fn run_speech(config: &Config, args: SpeechArgs) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn combine_speech_instructions(
|
||||
instruction: Option<String>,
|
||||
voice_prompt: Option<String>,
|
||||
) -> Option<String> {
|
||||
match (instruction, voice_prompt) {
|
||||
(Some(instruction), Some(voice_prompt)) => {
|
||||
let instruction = instruction.trim();
|
||||
let voice_prompt = voice_prompt.trim();
|
||||
if instruction.is_empty() {
|
||||
Some(voice_prompt.to_string()).filter(|value| !value.is_empty())
|
||||
} else if voice_prompt.is_empty() {
|
||||
Some(instruction.to_string()).filter(|value| !value.is_empty())
|
||||
} else {
|
||||
Some(format!("{voice_prompt}\n\n{instruction}"))
|
||||
}
|
||||
}
|
||||
(Some(value), None) | (None, Some(value)) => {
|
||||
let value = value.trim().to_string();
|
||||
if value.is_empty() { None } else { Some(value) }
|
||||
}
|
||||
(None, None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
const VOICE_CLONE_BASE64_MAX_BYTES: usize = 10 * 1024 * 1024;
|
||||
|
||||
fn normalize_speech_format(format: &str) -> Option<String> {
|
||||
let normalized = format.trim().to_ascii_lowercase();
|
||||
match normalized.as_str() {
|
||||
"wav" | "mp3" | "pcm16" => Some(normalized),
|
||||
"pcm" => Some("pcm16".to_string()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn default_speech_output_name(format: &str) -> String {
|
||||
format!(
|
||||
"speech.{}",
|
||||
normalize_speech_format(format).as_deref().unwrap_or("wav")
|
||||
)
|
||||
}
|
||||
|
||||
fn resolve_speech_output_path(
|
||||
output: Option<PathBuf>,
|
||||
output_dir: Option<PathBuf>,
|
||||
format: &str,
|
||||
) -> PathBuf {
|
||||
output.unwrap_or_else(|| {
|
||||
output_dir
|
||||
.unwrap_or_default()
|
||||
.join(default_speech_output_name(format))
|
||||
})
|
||||
}
|
||||
|
||||
fn encode_voice_clone_data_uri(path: &Path) -> Result<String> {
|
||||
let bytes = std::fs::read(path)
|
||||
.with_context(|| format!("Failed to read voice clone sample {}", path.display()))?;
|
||||
let base64_audio = general_purpose::STANDARD.encode(bytes);
|
||||
if base64_audio.len() > VOICE_CLONE_BASE64_MAX_BYTES {
|
||||
bail!(
|
||||
"Voice clone sample is too large after base64 encoding ({} bytes > 10 MB)",
|
||||
base64_audio.len()
|
||||
);
|
||||
}
|
||||
|
||||
let extension = path
|
||||
.extension()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
let mime = match extension.as_str() {
|
||||
"mp3" => "audio/mpeg",
|
||||
"wav" => "audio/wav",
|
||||
other => bail!(
|
||||
"Unsupported voice clone sample extension '{}'. Use .mp3 or .wav.",
|
||||
other
|
||||
),
|
||||
};
|
||||
|
||||
Ok(format!("data:{mime};base64,{base64_audio}"))
|
||||
}
|
||||
|
||||
fn describe_speech_voice(voice: &str) -> String {
|
||||
if voice.starts_with("data:") {
|
||||
"embedded voice clone sample".to_string()
|
||||
} else {
|
||||
voice.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod speech_cli_tests {
|
||||
use super::*;
|
||||
use crate::tools::speech::{
|
||||
default_speech_output_name, infer_speech_model, normalize_speech_format,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn normalizes_documented_speech_formats() {
|
||||
@@ -3804,18 +3714,52 @@ mod speech_cli_tests {
|
||||
#[test]
|
||||
fn default_speech_output_tracks_requested_format() {
|
||||
assert_eq!(
|
||||
resolve_speech_output_path(None, None, "mp3"),
|
||||
PathBuf::from(default_speech_output_name("mp3")),
|
||||
PathBuf::from("speech.mp3")
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_speech_output_path(None, Some(PathBuf::from("audio")), "pcm"),
|
||||
PathBuf::from("audio").join(default_speech_output_name("pcm")),
|
||||
PathBuf::from("audio").join("speech.pcm16")
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_speech_output_path(Some(PathBuf::from("custom.wav")), None, "mp3"),
|
||||
Some(PathBuf::from("custom.wav"))
|
||||
.unwrap_or_else(|| PathBuf::from(default_speech_output_name("mp3"))),
|
||||
PathBuf::from("custom.wav")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn speech_command_parses_cli_passthrough_smoke() {
|
||||
let cli = Cli::try_parse_from([
|
||||
"codewhale-tui",
|
||||
"speech",
|
||||
"hello",
|
||||
"--model",
|
||||
"tts",
|
||||
"--format",
|
||||
"pcm",
|
||||
"--output-dir",
|
||||
"audio",
|
||||
"--voice",
|
||||
"Mia",
|
||||
])
|
||||
.expect("speech command parses");
|
||||
|
||||
let Some(Commands::Speech(args)) = cli.command else {
|
||||
panic!("expected speech command");
|
||||
};
|
||||
assert_eq!(args.text, "hello");
|
||||
assert_eq!(
|
||||
infer_speech_model(args.model.as_deref(), false, false),
|
||||
"mimo-v2.5-tts"
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_speech_format(&args.format).as_deref(),
|
||||
Some("pcm16")
|
||||
);
|
||||
assert_eq!(args.output_dir, Some(PathBuf::from("audio")));
|
||||
assert_eq!(args.voice.as_deref(), Some("Mia"));
|
||||
}
|
||||
}
|
||||
|
||||
/// Test API connectivity by making a minimal request
|
||||
|
||||
@@ -975,12 +975,13 @@ impl ToolRegistryBuilder {
|
||||
plan_state: super::plan::SharedPlanState,
|
||||
) -> Self {
|
||||
let speech_client = client.clone();
|
||||
let speech_output_dir = runtime.speech_output_dir.clone();
|
||||
self.with_agent_tools(allow_shell)
|
||||
.with_todo_tool(todo_list)
|
||||
.with_plan_tool(plan_state)
|
||||
.with_review_tool(client.clone(), model.clone())
|
||||
.with_rlm_tool(client, model)
|
||||
.with_speech_tools(speech_client, None)
|
||||
.with_speech_tools(speech_client, speech_output_dir)
|
||||
.with_recall_archive_tool()
|
||||
.with_subagent_tools(manager, runtime)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use async_trait::async_trait;
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use serde_json::{Value, json};
|
||||
@@ -19,23 +20,19 @@ use super::spec::{
|
||||
optional_bool, optional_str, required_str,
|
||||
};
|
||||
|
||||
const DEFAULT_FORMAT: &str = "wav";
|
||||
const DEFAULT_VOICE: &str = "mimo_default";
|
||||
pub(crate) const DEFAULT_FORMAT: &str = "wav";
|
||||
pub(crate) const DEFAULT_VOICE: &str = "mimo_default";
|
||||
const VOICE_CLONE_BASE64_MAX_BYTES: usize = 10 * 1024 * 1024;
|
||||
const SUPPORTED_SPEECH_FORMATS: &[&str] = &["wav", "mp3", "pcm16"];
|
||||
pub(crate) const SUPPORTED_SPEECH_FORMATS: &[&str] = &["wav", "mp3", "pcm16"];
|
||||
|
||||
pub const SUPPORTED_XIAOMI_MIMO_SPEECH_MODELS: &[&str] = &[
|
||||
"mimo-v2.5-pro",
|
||||
"mimo-v2.5",
|
||||
"mimo-v2.5-tts-voiceclone",
|
||||
"mimo-v2.5-tts-voicedesign",
|
||||
"mimo-v2.5-tts",
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2-omni",
|
||||
"mimo-v2-tts",
|
||||
];
|
||||
|
||||
const SPEECH_MODEL_EXAMPLES: &[&str] = &[
|
||||
pub(crate) const SPEECH_MODEL_EXAMPLES: &[&str] = &[
|
||||
"mimo-v2.5-tts",
|
||||
"mimo-v2.5-tts-voicedesign",
|
||||
"mimo-v2.5-tts-voiceclone",
|
||||
@@ -302,7 +299,7 @@ impl ToolSpec for SpeechTool {
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_speech_model(
|
||||
pub(crate) fn infer_speech_model(
|
||||
model: Option<&str>,
|
||||
has_clone_voice: bool,
|
||||
has_voice_prompt: bool,
|
||||
@@ -316,7 +313,7 @@ fn infer_speech_model(
|
||||
}
|
||||
}
|
||||
|
||||
fn combine_speech_instructions(
|
||||
pub(crate) fn combine_speech_instructions(
|
||||
instruction: Option<String>,
|
||||
voice_prompt: Option<String>,
|
||||
) -> Option<String> {
|
||||
@@ -340,7 +337,7 @@ fn combine_speech_instructions(
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_speech_format(format: &str) -> Option<String> {
|
||||
pub(crate) fn normalize_speech_format(format: &str) -> Option<String> {
|
||||
let normalized = format.trim().to_ascii_lowercase();
|
||||
match normalized.as_str() {
|
||||
"wav" | "mp3" | "pcm16" => Some(normalized),
|
||||
@@ -349,7 +346,7 @@ fn normalize_speech_format(format: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
fn default_speech_output_name(format: &str) -> String {
|
||||
pub(crate) fn default_speech_output_name(format: &str) -> String {
|
||||
format!(
|
||||
"speech.{}",
|
||||
normalize_speech_format(format)
|
||||
@@ -391,12 +388,25 @@ async fn encode_voice_clone_data_uri(path: &Path) -> Result<String, ToolError> {
|
||||
path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
voice_clone_data_uri_from_bytes(path, &bytes)
|
||||
.map_err(|err| ToolError::invalid_input(err.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) fn encode_voice_clone_sample_data_uri(path: &Path) -> anyhow::Result<String> {
|
||||
let bytes = std::fs::read(path)
|
||||
.with_context(|| format!("Failed to read voice clone sample {}", path.display()))?;
|
||||
|
||||
voice_clone_data_uri_from_bytes(path, &bytes)
|
||||
}
|
||||
|
||||
fn voice_clone_data_uri_from_bytes(path: &Path, bytes: &[u8]) -> anyhow::Result<String> {
|
||||
let base64_audio = general_purpose::STANDARD.encode(bytes);
|
||||
if base64_audio.len() > VOICE_CLONE_BASE64_MAX_BYTES {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
anyhow::bail!(
|
||||
"voice clone sample is too large after base64 encoding ({} bytes > 10 MB)",
|
||||
base64_audio.len()
|
||||
)));
|
||||
);
|
||||
}
|
||||
|
||||
let extension = path
|
||||
@@ -408,16 +418,14 @@ async fn encode_voice_clone_data_uri(path: &Path) -> Result<String, ToolError> {
|
||||
"mp3" => "audio/mpeg",
|
||||
"wav" => "audio/wav",
|
||||
other => {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"unsupported voice clone sample extension '{other}'. Use .mp3 or .wav."
|
||||
)));
|
||||
anyhow::bail!("unsupported voice clone sample extension '{other}'. Use .mp3 or .wav.");
|
||||
}
|
||||
};
|
||||
|
||||
Ok(format!("data:{mime};base64,{base64_audio}"))
|
||||
}
|
||||
|
||||
fn describe_speech_voice(voice: &str) -> String {
|
||||
pub(crate) fn describe_speech_voice(voice: &str) -> String {
|
||||
if voice.starts_with("data:") {
|
||||
"embedded voice clone sample".to_string()
|
||||
} else {
|
||||
@@ -502,6 +510,37 @@ mod tests {
|
||||
assert_eq!(normalize_speech_format("flac"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supported_xiaomi_mimo_speech_models_are_tts_only() {
|
||||
assert!(
|
||||
SUPPORTED_XIAOMI_MIMO_SPEECH_MODELS
|
||||
.iter()
|
||||
.all(|model| model.to_ascii_lowercase().contains("tts")),
|
||||
"model-visible speech list must not include chat-only MiMo models"
|
||||
);
|
||||
assert!(SUPPORTED_XIAOMI_MIMO_SPEECH_MODELS.contains(&"mimo-v2.5-tts"));
|
||||
assert!(!SUPPORTED_XIAOMI_MIMO_SPEECH_MODELS.contains(&"mimo-v2.5-pro"));
|
||||
assert!(!SUPPORTED_XIAOMI_MIMO_SPEECH_MODELS.contains(&"mimo-v2.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn configured_output_dir_is_used_for_default_tool_output() {
|
||||
let tmp = tempfile::tempdir().expect("tempdir");
|
||||
let context = ToolContext::new(tmp.path().to_path_buf());
|
||||
let configured = tmp.path().join("speech-artifacts");
|
||||
|
||||
let output = resolve_speech_output_path(
|
||||
&json!({"text": "hello"}),
|
||||
&context,
|
||||
None,
|
||||
"pcm",
|
||||
Some(&configured),
|
||||
)
|
||||
.expect("output path");
|
||||
|
||||
assert_eq!(output, configured.join("speech.pcm16"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn displays_openai_compatible_base_url() {
|
||||
assert_eq!(
|
||||
|
||||
@@ -800,6 +800,10 @@ pub struct SubAgentRuntime {
|
||||
/// false-timeout the child mid-thinking. `child_runtime()` and
|
||||
/// `background_runtime()` preserve the parent's value (#1806, #1808).
|
||||
pub step_api_timeout: Duration,
|
||||
/// Default directory for Xiaomi MiMo speech/TTS tool outputs inherited by
|
||||
/// child registries. Keeps parent and sub-agent `speech` / `tts` tools on
|
||||
/// the same `[speech].output_dir` / env override.
|
||||
pub speech_output_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl SubAgentRuntime {
|
||||
@@ -835,6 +839,7 @@ impl SubAgentRuntime {
|
||||
fork_context: None,
|
||||
mcp_pool: None,
|
||||
step_api_timeout: DEFAULT_STEP_API_TIMEOUT,
|
||||
speech_output_dir: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -858,6 +863,13 @@ impl SubAgentRuntime {
|
||||
self
|
||||
}
|
||||
|
||||
/// Preserve the configured speech output directory for sub-agent tools.
|
||||
#[must_use]
|
||||
pub fn with_speech_output_dir(mut self, output_dir: Option<PathBuf>) -> Self {
|
||||
self.speech_output_dir = output_dir;
|
||||
self
|
||||
}
|
||||
|
||||
/// Attach the wakeup channel so the engine's parent turn loop can resume
|
||||
/// when this runtime's direct children finish (issue #756). The channel
|
||||
/// is propagated to descendants via clone, but only `spawn_depth == 1`
|
||||
@@ -980,6 +992,7 @@ impl SubAgentRuntime {
|
||||
fork_context: self.fork_context.clone(),
|
||||
mcp_pool: self.mcp_pool.clone(),
|
||||
step_api_timeout: self.step_api_timeout,
|
||||
speech_output_dir: self.speech_output_dir.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1805,6 +1805,7 @@ fn stub_runtime() -> SubAgentRuntime {
|
||||
fork_context: None,
|
||||
mcp_pool: None,
|
||||
step_api_timeout: DEFAULT_STEP_API_TIMEOUT,
|
||||
speech_output_dir: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2036,6 +2037,16 @@ fn emit_parent_completion_fires_for_direct_child() {
|
||||
assert!(rx.try_recv().is_err(), "should be exactly one message");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_runtime_inherits_speech_output_dir() {
|
||||
let output_dir = PathBuf::from("configured-speech-output");
|
||||
let runtime = stub_runtime().with_speech_output_dir(Some(output_dir.clone()));
|
||||
|
||||
let child = runtime.child_runtime();
|
||||
|
||||
assert_eq!(child.speech_output_dir, Some(output_dir));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn emit_parent_completion_skips_grandchildren() {
|
||||
let (tx, mut rx) = mpsc::unbounded_channel::<SubAgentCompletion>();
|
||||
|
||||
Reference in New Issue
Block a user