hello world
This commit is contained in:
commit
c99507ca1e
84 changed files with 54252 additions and 0 deletions
31
src-rust/crates/core/Cargo.toml
Normal file
31
src-rust/crates/core/Cargo.toml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
[package]
|
||||
name = "cc-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
indexmap = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
url = { workspace = true }
|
||||
schemars = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
urlencoding = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
403
src-rust/crates/core/src/analytics.rs
Normal file
403
src-rust/crates/core/src/analytics.rs
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
//! Analytics and telemetry (OpenTelemetry-compatible counters)
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Session-level metrics counters (mirrors TypeScript bootstrap state).
|
||||
///
|
||||
/// All counters use `AtomicU64` so they can be shared across threads without
|
||||
/// a mutex. Cost is stored as integer millicents (cost_usd × 100_000) to
|
||||
/// avoid floating-point atomic arithmetic.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SessionMetrics {
|
||||
/// Total cost in units of 1/100_000 USD (i.e. millicents).
|
||||
pub total_cost_usd_millicents: AtomicU64,
|
||||
pub total_input_tokens: AtomicU64,
|
||||
pub total_output_tokens: AtomicU64,
|
||||
pub total_api_duration_ms: AtomicU64,
|
||||
pub total_tool_duration_ms: AtomicU64,
|
||||
pub total_lines_added: AtomicU64,
|
||||
pub total_lines_removed: AtomicU64,
|
||||
pub session_count: AtomicU64,
|
||||
pub commit_count: AtomicU64,
|
||||
pub pr_count: AtomicU64,
|
||||
pub tool_use_count: AtomicU64,
|
||||
}
|
||||
|
||||
impl SessionMetrics {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self::default())
|
||||
}
|
||||
|
||||
pub fn add_cost(&self, usd: f64) {
|
||||
let millicents = (usd * 100_000.0) as u64;
|
||||
self.total_cost_usd_millicents
|
||||
.fetch_add(millicents, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn total_cost_usd(&self) -> f64 {
|
||||
self.total_cost_usd_millicents.load(Ordering::Relaxed) as f64 / 100_000.0
|
||||
}
|
||||
|
||||
pub fn add_tokens(&self, input: u32, output: u32) {
|
||||
self.total_input_tokens
|
||||
.fetch_add(input as u64, Ordering::Relaxed);
|
||||
self.total_output_tokens
|
||||
.fetch_add(output as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_api_duration(&self, ms: u64) {
|
||||
self.total_api_duration_ms.fetch_add(ms, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_tool_duration(&self, ms: u64) {
|
||||
self.total_tool_duration_ms.fetch_add(ms, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_lines(&self, added: i64, removed: i64) {
|
||||
if added > 0 {
|
||||
self.total_lines_added
|
||||
.fetch_add(added as u64, Ordering::Relaxed);
|
||||
}
|
||||
if removed > 0 {
|
||||
self.total_lines_removed
|
||||
.fetch_add(removed as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_commits(&self) {
|
||||
self.commit_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_prs(&self) {
|
||||
self.pr_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_tool_use(&self) {
|
||||
self.tool_use_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> MetricsSummary {
|
||||
MetricsSummary {
|
||||
cost_usd: self.total_cost_usd(),
|
||||
input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
|
||||
output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
|
||||
api_duration_ms: self.total_api_duration_ms.load(Ordering::Relaxed),
|
||||
tool_duration_ms: self.total_tool_duration_ms.load(Ordering::Relaxed),
|
||||
lines_added: self.total_lines_added.load(Ordering::Relaxed),
|
||||
lines_removed: self.total_lines_removed.load(Ordering::Relaxed),
|
||||
commits: self.commit_count.load(Ordering::Relaxed),
|
||||
prs: self.pr_count.load(Ordering::Relaxed),
|
||||
tool_uses: self.tool_use_count.load(Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A point-in-time snapshot of session metrics.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetricsSummary {
|
||||
pub cost_usd: f64,
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub api_duration_ms: u64,
|
||||
pub tool_duration_ms: u64,
|
||||
pub lines_added: u64,
|
||||
pub lines_removed: u64,
|
||||
pub commits: u64,
|
||||
pub prs: u64,
|
||||
pub tool_uses: u64,
|
||||
}
|
||||
|
||||
impl MetricsSummary {
|
||||
/// Format cost as a dollar amount string with appropriate precision.
|
||||
pub fn format_cost(&self) -> String {
|
||||
if self.cost_usd < 0.01 {
|
||||
format!("${:.5}", self.cost_usd)
|
||||
} else {
|
||||
format!("${:.4}", self.cost_usd)
|
||||
}
|
||||
}
|
||||
|
||||
/// Format total token count with K/M suffix.
|
||||
pub fn format_tokens(&self) -> String {
|
||||
let total = self.input_tokens + self.output_tokens;
|
||||
if total >= 1_000_000 {
|
||||
format!("{:.1}M tok", total as f64 / 1_000_000.0)
|
||||
} else if total >= 1_000 {
|
||||
format!("{:.1}K tok", total as f64 / 1_000.0)
|
||||
} else {
|
||||
format!("{} tok", total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Event types for first-party analytics (privacy-respecting — no PII).
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AnalyticsEvent {
|
||||
SessionStarted {
|
||||
model: String,
|
||||
is_interactive: bool,
|
||||
},
|
||||
SessionEnded {
|
||||
turn_count: u32,
|
||||
cost_usd: f64,
|
||||
duration_ms: u64,
|
||||
had_errors: bool,
|
||||
},
|
||||
ToolUsed {
|
||||
tool_name: String,
|
||||
success: bool,
|
||||
duration_ms: u64,
|
||||
},
|
||||
CommandExecuted {
|
||||
command: String,
|
||||
success: bool,
|
||||
},
|
||||
CompactionTriggered {
|
||||
tokens_before: u32,
|
||||
tokens_after: u32,
|
||||
},
|
||||
}
|
||||
|
||||
/// Analytics sink — currently logs via `tracing`; can be extended to push
|
||||
/// events to a first-party endpoint.
|
||||
pub struct Analytics {
|
||||
enabled: bool,
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
impl Analytics {
|
||||
pub fn new(session_id: String, enabled: bool) -> Self {
|
||||
Self {
|
||||
enabled,
|
||||
session_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn track(&self, event: AnalyticsEvent) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
tracing::debug!(
|
||||
session_id = %self.session_id,
|
||||
event = ?event,
|
||||
"analytics event"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_session_metrics_initial_zero() {
|
||||
let m = SessionMetrics::new();
|
||||
assert_eq!(m.total_cost_usd(), 0.0);
|
||||
assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_cost_single() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_cost(0.01);
|
||||
let cost = m.total_cost_usd();
|
||||
// Allow small floating-point tolerance
|
||||
assert!((cost - 0.01).abs() < 1e-9, "cost = {}", cost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_cost_accumulates() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_cost(1.0);
|
||||
m.add_cost(2.5);
|
||||
let cost = m.total_cost_usd();
|
||||
assert!((cost - 3.5).abs() < 1e-9, "cost = {}", cost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_tokens() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_tokens(1000, 500);
|
||||
assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 1000);
|
||||
assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_tokens_accumulates() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_tokens(1000, 500);
|
||||
m.add_tokens(200, 100);
|
||||
assert_eq!(m.total_input_tokens.load(Ordering::Relaxed), 1200);
|
||||
assert_eq!(m.total_output_tokens.load(Ordering::Relaxed), 600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_lines_positive() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_lines(10, 5);
|
||||
assert_eq!(m.total_lines_added.load(Ordering::Relaxed), 10);
|
||||
assert_eq!(m.total_lines_removed.load(Ordering::Relaxed), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_lines_negative_ignored() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_lines(-3, -7);
|
||||
assert_eq!(m.total_lines_added.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(m.total_lines_removed.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_increment_commits_and_prs() {
|
||||
let m = SessionMetrics::new();
|
||||
m.increment_commits();
|
||||
m.increment_commits();
|
||||
m.increment_prs();
|
||||
assert_eq!(m.commit_count.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(m.pr_count.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_increment_tool_use() {
|
||||
let m = SessionMetrics::new();
|
||||
for _ in 0..5 {
|
||||
m.increment_tool_use();
|
||||
}
|
||||
assert_eq!(m.tool_use_count.load(Ordering::Relaxed), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_summary_snapshot() {
|
||||
let m = SessionMetrics::new();
|
||||
m.add_cost(1.23456);
|
||||
m.add_tokens(100, 50);
|
||||
m.add_api_duration(300);
|
||||
m.add_tool_duration(150);
|
||||
m.add_lines(8, 3);
|
||||
m.increment_commits();
|
||||
m.increment_prs();
|
||||
m.increment_tool_use();
|
||||
|
||||
let s = m.summary();
|
||||
assert!((s.cost_usd - 1.23456).abs() < 1e-9);
|
||||
assert_eq!(s.input_tokens, 100);
|
||||
assert_eq!(s.output_tokens, 50);
|
||||
assert_eq!(s.api_duration_ms, 300);
|
||||
assert_eq!(s.tool_duration_ms, 150);
|
||||
assert_eq!(s.lines_added, 8);
|
||||
assert_eq!(s.lines_removed, 3);
|
||||
assert_eq!(s.commits, 1);
|
||||
assert_eq!(s.prs, 1);
|
||||
assert_eq!(s.tool_uses, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_cost_small() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.001,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
let formatted = s.format_cost();
|
||||
assert!(formatted.starts_with('$'));
|
||||
// Should have 5 decimal places for small cost
|
||||
assert!(formatted.contains('.'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_cost_large() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 1.5,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert_eq!(s.format_cost(), "$1.5000");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tokens_exact() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.0,
|
||||
input_tokens: 500,
|
||||
output_tokens: 300,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert_eq!(s.format_tokens(), "800 tok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tokens_kilo() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.0,
|
||||
input_tokens: 5_000,
|
||||
output_tokens: 3_000,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert!(s.format_tokens().ends_with("K tok"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tokens_mega() {
|
||||
let s = MetricsSummary {
|
||||
cost_usd: 0.0,
|
||||
input_tokens: 1_500_000,
|
||||
output_tokens: 500_000,
|
||||
api_duration_ms: 0,
|
||||
tool_duration_ms: 0,
|
||||
lines_added: 0,
|
||||
lines_removed: 0,
|
||||
commits: 0,
|
||||
prs: 0,
|
||||
tool_uses: 0,
|
||||
};
|
||||
assert!(s.format_tokens().ends_with("M tok"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analytics_track_disabled_no_panic() {
|
||||
let a = Analytics::new("test-session".to_string(), false);
|
||||
// Should not panic even though disabled
|
||||
a.track(AnalyticsEvent::SessionStarted {
|
||||
model: "claude-opus-4-6".to_string(),
|
||||
is_interactive: true,
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analytics_track_enabled_no_panic() {
|
||||
let a = Analytics::new("test-session-2".to_string(), true);
|
||||
a.track(AnalyticsEvent::ToolUsed {
|
||||
tool_name: "Bash".to_string(),
|
||||
success: true,
|
||||
duration_ms: 42,
|
||||
});
|
||||
}
|
||||
}
|
||||
423
src-rust/crates/core/src/keybindings.rs
Normal file
423
src-rust/crates/core/src/keybindings.rs
Normal file
|
|
@ -0,0 +1,423 @@
|
|||
//! Configurable keyboard shortcuts system
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// All keybinding contexts
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum KeyContext {
|
||||
Global,
|
||||
Chat,
|
||||
Autocomplete,
|
||||
Confirmation,
|
||||
Help,
|
||||
Transcript,
|
||||
HistorySearch,
|
||||
Task,
|
||||
ThemePicker,
|
||||
Settings,
|
||||
Tabs,
|
||||
Attachments,
|
||||
Footer,
|
||||
MessageSelector,
|
||||
DiffDialog,
|
||||
ModelPicker,
|
||||
Select,
|
||||
Plugin,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ParsedKeystroke {
|
||||
pub key: String, // normalized key name
|
||||
pub ctrl: bool,
|
||||
pub alt: bool,
|
||||
pub shift: bool,
|
||||
pub meta: bool,
|
||||
}
|
||||
|
||||
pub type Chord = Vec<ParsedKeystroke>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ParsedBinding {
|
||||
pub chord: Chord,
|
||||
pub action: Option<String>, // None = unbound
|
||||
pub context: KeyContext,
|
||||
}
|
||||
|
||||
/// Parse a keystroke string like "ctrl+shift+enter" into ParsedKeystroke
|
||||
pub fn parse_keystroke(s: &str) -> Option<ParsedKeystroke> {
|
||||
let s = s.trim().to_lowercase();
|
||||
let mut ctrl = false;
|
||||
let mut alt = false;
|
||||
let mut shift = false;
|
||||
let mut meta = false;
|
||||
let mut key_parts: Vec<&str> = Vec::new();
|
||||
|
||||
for part in s.split('+') {
|
||||
let part = part.trim();
|
||||
match part {
|
||||
"ctrl" | "control" => ctrl = true,
|
||||
"alt" | "opt" | "option" => alt = true,
|
||||
"shift" => shift = true,
|
||||
"meta" | "cmd" | "command" | "super" | "win" => meta = true,
|
||||
_ => key_parts.push(part),
|
||||
}
|
||||
}
|
||||
|
||||
if key_parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let key = normalize_key(key_parts.join("+").as_str());
|
||||
Some(ParsedKeystroke {
|
||||
key,
|
||||
ctrl,
|
||||
alt,
|
||||
shift,
|
||||
meta,
|
||||
})
|
||||
}
|
||||
|
||||
fn normalize_key(k: &str) -> String {
|
||||
match k {
|
||||
"esc" | "escape" => "escape".to_string(),
|
||||
"return" | "enter" => "enter".to_string(),
|
||||
"del" | "delete" => "delete".to_string(),
|
||||
"backspace" | "bs" => "backspace".to_string(),
|
||||
"space" | " " => "space".to_string(),
|
||||
"up" => "up".to_string(),
|
||||
"down" => "down".to_string(),
|
||||
"left" => "left".to_string(),
|
||||
"right" => "right".to_string(),
|
||||
"pageup" | "pgup" => "pageup".to_string(),
|
||||
"pagedown" | "pgdn" | "pgdown" => "pagedown".to_string(),
|
||||
"home" => "home".to_string(),
|
||||
"end" => "end".to_string(),
|
||||
"tab" => "tab".to_string(),
|
||||
k => k.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a chord (space-separated keystrokes like "ctrl+k ctrl+d")
|
||||
pub fn parse_chord(s: &str) -> Option<Chord> {
|
||||
let keystrokes: Vec<ParsedKeystroke> =
|
||||
s.split_whitespace().filter_map(parse_keystroke).collect();
|
||||
if keystrokes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(keystrokes)
|
||||
}
|
||||
}
|
||||
|
||||
/// Keys that cannot be rebound
|
||||
pub const NON_REBINDABLE: &[&str] = &["ctrl+c", "ctrl+d", "ctrl+m"];
|
||||
|
||||
/// Default keybindings
|
||||
pub fn default_bindings() -> Vec<ParsedBinding> {
|
||||
let defaults: &[(&str, &str, KeyContext)] = &[
|
||||
// Global
|
||||
("ctrl+c", "interrupt", KeyContext::Global),
|
||||
("ctrl+d", "exit", KeyContext::Global),
|
||||
("ctrl+l", "redraw", KeyContext::Global),
|
||||
("ctrl+r", "historySearch", KeyContext::Global),
|
||||
// Chat
|
||||
("enter", "submit", KeyContext::Chat),
|
||||
("up", "historyPrev", KeyContext::Chat),
|
||||
("down", "historyNext", KeyContext::Chat),
|
||||
("shift+tab", "cycleMode", KeyContext::Chat),
|
||||
("pageup", "scrollUp", KeyContext::Chat),
|
||||
("pagedown", "scrollDown", KeyContext::Chat),
|
||||
// Confirmation
|
||||
("y", "yes", KeyContext::Confirmation),
|
||||
("enter", "yes", KeyContext::Confirmation),
|
||||
("n", "no", KeyContext::Confirmation),
|
||||
("escape", "no", KeyContext::Confirmation),
|
||||
("up", "prevOption", KeyContext::Confirmation),
|
||||
("down", "nextOption", KeyContext::Confirmation),
|
||||
// Help
|
||||
("escape", "close", KeyContext::Help),
|
||||
("q", "close", KeyContext::Help),
|
||||
// HistorySearch
|
||||
("enter", "select", KeyContext::HistorySearch),
|
||||
("escape", "cancel", KeyContext::HistorySearch),
|
||||
("up", "prevResult", KeyContext::HistorySearch),
|
||||
("down", "nextResult", KeyContext::HistorySearch),
|
||||
];
|
||||
|
||||
defaults
|
||||
.iter()
|
||||
.filter_map(|(chord_str, action, context)| {
|
||||
parse_chord(chord_str).map(|chord| ParsedBinding {
|
||||
chord,
|
||||
action: Some(action.to_string()),
|
||||
context: context.clone(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// User keybindings loaded from ~/.claude/keybindings.json
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct UserKeybindings {
|
||||
pub bindings: Vec<UserBinding>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserBinding {
|
||||
pub chord: String, // e.g. "ctrl+k ctrl+d"
|
||||
pub action: Option<String>, // None = unbound
|
||||
pub context: Option<String>,
|
||||
}
|
||||
|
||||
impl UserKeybindings {
|
||||
pub fn load(config_dir: &Path) -> Self {
|
||||
let path = config_dir.join("keybindings.json");
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
serde_json::from_str(&content).unwrap_or_default()
|
||||
} else {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save(&self, config_dir: &Path) -> anyhow::Result<()> {
|
||||
let path = config_dir.join("keybindings.json");
|
||||
let json = serde_json::to_string_pretty(self)?;
|
||||
std::fs::write(path, json)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolved keybindings (defaults merged with user overrides)
|
||||
pub struct KeybindingResolver {
|
||||
bindings: Vec<ParsedBinding>,
|
||||
pending_chord: Vec<ParsedKeystroke>,
|
||||
}
|
||||
|
||||
impl KeybindingResolver {
|
||||
pub fn new(user: &UserKeybindings) -> Self {
|
||||
let mut bindings = default_bindings();
|
||||
|
||||
// Apply user overrides (user bindings win, last match wins)
|
||||
for user_binding in &user.bindings {
|
||||
if let Some(chord) = parse_chord(&user_binding.chord) {
|
||||
let context = user_binding
|
||||
.context
|
||||
.as_deref()
|
||||
.and_then(|c| serde_json::from_str(&format!("\"{}\"", c)).ok())
|
||||
.unwrap_or(KeyContext::Global);
|
||||
|
||||
bindings.push(ParsedBinding {
|
||||
chord,
|
||||
action: user_binding.action.clone(),
|
||||
context,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
bindings,
|
||||
pending_chord: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a keystroke, returns action if binding matches
|
||||
pub fn process(
|
||||
&mut self,
|
||||
keystroke: ParsedKeystroke,
|
||||
context: &KeyContext,
|
||||
) -> KeybindingResult {
|
||||
self.pending_chord.push(keystroke);
|
||||
|
||||
// Find matching bindings in current context + Global
|
||||
let matches: Vec<&ParsedBinding> = self
|
||||
.bindings
|
||||
.iter()
|
||||
.filter(|b| &b.context == context || b.context == KeyContext::Global)
|
||||
.filter(|b| b.chord.starts_with(self.pending_chord.as_slice()))
|
||||
.collect();
|
||||
|
||||
if matches.is_empty() {
|
||||
self.pending_chord.clear();
|
||||
return KeybindingResult::NoMatch;
|
||||
}
|
||||
|
||||
let exact: Vec<&ParsedBinding> = matches
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|b| b.chord.len() == self.pending_chord.len())
|
||||
.collect();
|
||||
|
||||
if !exact.is_empty() {
|
||||
// Last match wins (user overrides)
|
||||
let binding = exact.last().unwrap();
|
||||
self.pending_chord.clear();
|
||||
return match &binding.action {
|
||||
Some(action) => KeybindingResult::Action(action.clone()),
|
||||
None => KeybindingResult::Unbound,
|
||||
};
|
||||
}
|
||||
|
||||
// Chord in progress
|
||||
KeybindingResult::Pending
|
||||
}
|
||||
|
||||
pub fn cancel_chord(&mut self) {
|
||||
self.pending_chord.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ParsedKeystroke {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.key == other.key
|
||||
&& self.ctrl == other.ctrl
|
||||
&& self.alt == other.alt
|
||||
&& self.shift == other.shift
|
||||
&& self.meta == other.meta
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeybindingResult {
|
||||
Action(String),
|
||||
Unbound,
|
||||
Pending,
|
||||
NoMatch,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_simple() {
|
||||
let ks = parse_keystroke("enter").unwrap();
|
||||
assert_eq!(ks.key, "enter");
|
||||
assert!(!ks.ctrl);
|
||||
assert!(!ks.alt);
|
||||
assert!(!ks.shift);
|
||||
assert!(!ks.meta);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_ctrl_c() {
|
||||
let ks = parse_keystroke("ctrl+c").unwrap();
|
||||
assert_eq!(ks.key, "c");
|
||||
assert!(ks.ctrl);
|
||||
assert!(!ks.alt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_ctrl_shift_enter() {
|
||||
let ks = parse_keystroke("ctrl+shift+enter").unwrap();
|
||||
assert_eq!(ks.key, "enter");
|
||||
assert!(ks.ctrl);
|
||||
assert!(ks.shift);
|
||||
assert!(!ks.alt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_normalizes_esc() {
|
||||
let ks = parse_keystroke("esc").unwrap();
|
||||
assert_eq!(ks.key, "escape");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_normalizes_return() {
|
||||
let ks = parse_keystroke("return").unwrap();
|
||||
assert_eq!(ks.key, "enter");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_keystroke_empty_returns_none() {
|
||||
assert!(parse_keystroke("ctrl+").is_none());
|
||||
assert!(parse_keystroke("").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_chord_single() {
|
||||
let chord = parse_chord("ctrl+c").unwrap();
|
||||
assert_eq!(chord.len(), 1);
|
||||
assert_eq!(chord[0].key, "c");
|
||||
assert!(chord[0].ctrl);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_chord_multi() {
|
||||
let chord = parse_chord("ctrl+k ctrl+d").unwrap();
|
||||
assert_eq!(chord.len(), 2);
|
||||
assert_eq!(chord[0].key, "k");
|
||||
assert_eq!(chord[1].key, "d");
|
||||
assert!(chord[0].ctrl);
|
||||
assert!(chord[1].ctrl);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_chord_empty_returns_none() {
|
||||
assert!(parse_chord("").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_bindings_not_empty() {
|
||||
let bindings = default_bindings();
|
||||
assert!(!bindings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_bindings_contains_ctrl_c() {
|
||||
let bindings = default_bindings();
|
||||
let ctrl_c = bindings.iter().find(|b| {
|
||||
b.chord.len() == 1
|
||||
&& b.chord[0].ctrl
|
||||
&& b.chord[0].key == "c"
|
||||
&& b.context == KeyContext::Global
|
||||
});
|
||||
assert!(ctrl_c.is_some());
|
||||
assert_eq!(ctrl_c.unwrap().action.as_deref(), Some("interrupt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolver_simple_action() {
|
||||
let user = UserKeybindings::default();
|
||||
let mut resolver = KeybindingResolver::new(&user);
|
||||
let ks = parse_keystroke("ctrl+c").unwrap();
|
||||
let result = resolver.process(ks, &KeyContext::Global);
|
||||
assert!(matches!(result, KeybindingResult::Action(ref a) if a == "interrupt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolver_no_match() {
|
||||
let user = UserKeybindings::default();
|
||||
let mut resolver = KeybindingResolver::new(&user);
|
||||
// ctrl+z has no default binding
|
||||
let ks = parse_keystroke("ctrl+z").unwrap();
|
||||
let result = resolver.process(ks, &KeyContext::Chat);
|
||||
assert!(matches!(result, KeybindingResult::NoMatch));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolver_context_match_global_from_chat() {
|
||||
let user = UserKeybindings::default();
|
||||
let mut resolver = KeybindingResolver::new(&user);
|
||||
// ctrl+l is Global, should match even when context is Chat
|
||||
let ks = parse_keystroke("ctrl+l").unwrap();
|
||||
let result = resolver.process(ks, &KeyContext::Chat);
|
||||
assert!(matches!(result, KeybindingResult::Action(ref a) if a == "redraw"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keystroke_equality() {
|
||||
let ks1 = parse_keystroke("ctrl+enter").unwrap();
|
||||
let ks2 = parse_keystroke("ctrl+enter").unwrap();
|
||||
let ks3 = parse_keystroke("shift+enter").unwrap();
|
||||
assert_eq!(ks1, ks2);
|
||||
assert_ne!(ks1, ks3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_keybindings_default_empty() {
|
||||
let user = UserKeybindings::default();
|
||||
assert!(user.bindings.is_empty());
|
||||
}
|
||||
}
|
||||
2011
src-rust/crates/core/src/lib.rs
Normal file
2011
src-rust/crates/core/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
294
src-rust/crates/core/src/lsp.rs
Normal file
294
src-rust/crates/core/src/lsp.rs
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
//! Language Server Protocol client stub.
|
||||
//!
|
||||
//! The full LSP implementation is provided by plugins; this module defines
|
||||
//! the integration interface that the rest of the codebase uses to query
|
||||
//! diagnostics, register servers, and format output.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for a single LSP server process.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspServerConfig {
|
||||
/// Display name, e.g. "rust-analyzer"
|
||||
pub name: String,
|
||||
/// Path or name of the server binary, e.g. "rust-analyzer"
|
||||
pub command: String,
|
||||
/// Command-line arguments passed to the server binary
|
||||
pub args: Vec<String>,
|
||||
/// Glob patterns that activate this server, e.g. `["*.rs", "*.toml"]`
|
||||
pub file_patterns: Vec<String>,
|
||||
/// Optional server-specific initialization options (passed in LSP `initialize`)
|
||||
pub initialization_options: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// A single diagnostic emitted by an LSP server.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LspDiagnostic {
|
||||
/// Workspace-relative or absolute file path
|
||||
pub file: String,
|
||||
/// 1-based line number
|
||||
pub line: u32,
|
||||
/// 1-based column number
|
||||
pub column: u32,
|
||||
pub severity: DiagnosticSeverity,
|
||||
pub message: String,
|
||||
/// The LSP server that produced this diagnostic (e.g. "rust-analyzer")
|
||||
pub source: Option<String>,
|
||||
/// Diagnostic code (e.g. "E0308"), if provided by the server
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
/// Severity level of a diagnostic, matching the LSP spec.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum DiagnosticSeverity {
|
||||
Error = 1,
|
||||
Warning = 2,
|
||||
Information = 3,
|
||||
Hint = 4,
|
||||
}
|
||||
|
||||
impl DiagnosticSeverity {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Error => "error",
|
||||
Self::Warning => "warning",
|
||||
Self::Information => "info",
|
||||
Self::Hint => "hint",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LSP manager stub.
|
||||
///
|
||||
/// In the full implementation this will own LSP server processes and route
|
||||
/// JSON-RPC messages. For now it is a registry that tracks configured
|
||||
/// servers and returns empty diagnostic lists — the plugin system is
|
||||
/// responsible for wiring up real communication.
|
||||
pub struct LspManager {
|
||||
servers: Vec<LspServerConfig>,
|
||||
}
|
||||
|
||||
impl LspManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
servers: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an LSP server configuration.
|
||||
pub fn register_server(&mut self, config: LspServerConfig) {
|
||||
self.servers.push(config);
|
||||
}
|
||||
|
||||
/// Return all registered server configurations.
|
||||
pub fn servers(&self) -> &[LspServerConfig] {
|
||||
&self.servers
|
||||
}
|
||||
|
||||
/// Look up a server configuration by name.
|
||||
pub fn server_by_name(&self, name: &str) -> Option<&LspServerConfig> {
|
||||
self.servers.iter().find(|s| s.name == name)
|
||||
}
|
||||
|
||||
/// Get diagnostics for a file.
|
||||
///
|
||||
/// This stub always returns an empty list. When an LSP plugin connects it
|
||||
/// will replace this path with real RPC calls.
|
||||
pub async fn get_diagnostics(&self, _file: &str) -> Vec<LspDiagnostic> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Format a slice of diagnostics into a human-readable multi-line string
|
||||
/// suitable for inclusion in tool output or TUI display.
|
||||
pub fn format_diagnostics(diagnostics: &[LspDiagnostic]) -> String {
|
||||
if diagnostics.is_empty() {
|
||||
return "No diagnostics.".to_string();
|
||||
}
|
||||
diagnostics
|
||||
.iter()
|
||||
.map(|d| {
|
||||
format!(
|
||||
"[{}] {}:{}:{} - {}{}{}",
|
||||
d.severity.as_str().to_uppercase(),
|
||||
d.file,
|
||||
d.line,
|
||||
d.column,
|
||||
d.message,
|
||||
d.source
|
||||
.as_deref()
|
||||
.map(|s| format!(" ({})", s))
|
||||
.unwrap_or_default(),
|
||||
d.code
|
||||
.as_deref()
|
||||
.map(|c| format!(" [{}]", c))
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LspManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_config(name: &str) -> LspServerConfig {
|
||||
LspServerConfig {
|
||||
name: name.to_string(),
|
||||
command: name.to_string(),
|
||||
args: vec![],
|
||||
file_patterns: vec!["*.rs".to_string()],
|
||||
initialization_options: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_diagnostic(
|
||||
file: &str,
|
||||
line: u32,
|
||||
col: u32,
|
||||
severity: DiagnosticSeverity,
|
||||
message: &str,
|
||||
) -> LspDiagnostic {
|
||||
LspDiagnostic {
|
||||
file: file.to_string(),
|
||||
line,
|
||||
column: col,
|
||||
severity,
|
||||
message: message.to_string(),
|
||||
source: None,
|
||||
code: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_manager_empty() {
|
||||
let mgr = LspManager::new();
|
||||
assert!(mgr.servers().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_server() {
|
||||
let mut mgr = LspManager::new();
|
||||
mgr.register_server(make_config("rust-analyzer"));
|
||||
assert_eq!(mgr.servers().len(), 1);
|
||||
assert_eq!(mgr.servers()[0].name, "rust-analyzer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_multiple_servers() {
|
||||
let mut mgr = LspManager::new();
|
||||
mgr.register_server(make_config("rust-analyzer"));
|
||||
mgr.register_server(make_config("pyright"));
|
||||
assert_eq!(mgr.servers().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_by_name_found() {
|
||||
let mut mgr = LspManager::new();
|
||||
mgr.register_server(make_config("rust-analyzer"));
|
||||
mgr.register_server(make_config("pyright"));
|
||||
let s = mgr.server_by_name("pyright");
|
||||
assert!(s.is_some());
|
||||
assert_eq!(s.unwrap().name, "pyright");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_by_name_not_found() {
|
||||
let mgr = LspManager::new();
|
||||
assert!(mgr.server_by_name("missing").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_diagnostics_stub_empty() {
|
||||
let mgr = LspManager::new();
|
||||
let diags = mgr.get_diagnostics("src/main.rs").await;
|
||||
assert!(diags.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_empty() {
|
||||
let result = LspManager::format_diagnostics(&[]);
|
||||
assert_eq!(result, "No diagnostics.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_single_error() {
|
||||
let diags = vec![make_diagnostic(
|
||||
"src/lib.rs",
|
||||
10,
|
||||
5,
|
||||
DiagnosticSeverity::Error,
|
||||
"type mismatch",
|
||||
)];
|
||||
let result = LspManager::format_diagnostics(&diags);
|
||||
assert!(result.contains("[ERROR]"));
|
||||
assert!(result.contains("src/lib.rs"));
|
||||
assert!(result.contains("10:5"));
|
||||
assert!(result.contains("type mismatch"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_multiple() {
|
||||
let diags = vec![
|
||||
make_diagnostic("a.rs", 1, 1, DiagnosticSeverity::Error, "err1"),
|
||||
make_diagnostic("b.rs", 2, 3, DiagnosticSeverity::Warning, "warn1"),
|
||||
];
|
||||
let result = LspManager::format_diagnostics(&diags);
|
||||
let lines: Vec<&str> = result.lines().collect();
|
||||
assert_eq!(lines.len(), 2);
|
||||
assert!(lines[0].contains("[ERROR]"));
|
||||
assert!(lines[1].contains("[WARNING]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_diagnostics_with_source_and_code() {
|
||||
let mut d = make_diagnostic(
|
||||
"main.rs",
|
||||
5,
|
||||
1,
|
||||
DiagnosticSeverity::Error,
|
||||
"mismatched types",
|
||||
);
|
||||
d.source = Some("rust-analyzer".to_string());
|
||||
d.code = Some("E0308".to_string());
|
||||
let result = LspManager::format_diagnostics(&[d]);
|
||||
assert!(result.contains("(rust-analyzer)"), "result = {}", result);
|
||||
assert!(result.contains("[E0308]"), "result = {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diagnostic_severity_ordering() {
|
||||
assert!(DiagnosticSeverity::Error < DiagnosticSeverity::Warning);
|
||||
assert!(DiagnosticSeverity::Warning < DiagnosticSeverity::Information);
|
||||
assert!(DiagnosticSeverity::Information < DiagnosticSeverity::Hint);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diagnostic_severity_as_str() {
|
||||
assert_eq!(DiagnosticSeverity::Error.as_str(), "error");
|
||||
assert_eq!(DiagnosticSeverity::Warning.as_str(), "warning");
|
||||
assert_eq!(DiagnosticSeverity::Information.as_str(), "info");
|
||||
assert_eq!(DiagnosticSeverity::Hint.as_str(), "hint");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsp_server_config_serialization() {
|
||||
let cfg = make_config("rust-analyzer");
|
||||
let json = serde_json::to_string(&cfg).unwrap();
|
||||
let back: LspServerConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back.name, "rust-analyzer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_trait() {
|
||||
let mgr = LspManager::default();
|
||||
assert!(mgr.servers().is_empty());
|
||||
}
|
||||
}
|
||||
881
src-rust/crates/core/src/memdir.rs
Normal file
881
src-rust/crates/core/src/memdir.rs
Normal file
|
|
@ -0,0 +1,881 @@
|
|||
//! Memory directory (memdir) system.
|
||||
//!
|
||||
//! Provides persistent, file-based memory across sessions. Mirrors the
|
||||
//! TypeScript modules under `src/memdir/`:
|
||||
//! - `memoryScan.ts` → `scan_memory_dir`, `parse_frontmatter_quick`, `format_memory_manifest`
|
||||
//! - `memoryAge.ts` → `memory_age_days`, `memory_freshness_text`, `memory_freshness_note`
|
||||
//! - `memdir.ts` → `build_memory_prompt_content`, `load_memory_index`, `ensure_memory_dir_exists`
|
||||
//! - `paths.ts` → `auto_memory_path`, `is_auto_memory_enabled`
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Memory type taxonomy
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The four canonical memory types.
|
||||
/// Matches the TypeScript `MemoryType` union in `memoryTypes.ts`.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MemoryType {
|
||||
/// Information about the user's role, goals, and preferences.
|
||||
User,
|
||||
/// Guidance the user has given about how to approach work.
|
||||
Feedback,
|
||||
/// Information about ongoing work, goals, or incidents in the project.
|
||||
Project,
|
||||
/// Pointers to where information lives in external systems.
|
||||
Reference,
|
||||
}
|
||||
|
||||
impl MemoryType {
|
||||
/// Parse a raw frontmatter value into a `MemoryType`.
|
||||
/// Returns `None` for missing or unrecognised values (legacy files degrade gracefully).
|
||||
pub fn parse(raw: &str) -> Option<Self> {
|
||||
match raw.trim() {
|
||||
"user" => Some(Self::User),
|
||||
"feedback" => Some(Self::Feedback),
|
||||
"project" => Some(Self::Project),
|
||||
"reference" => Some(Self::Reference),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Display as a lowercase string.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::User => "user",
|
||||
Self::Feedback => "feedback",
|
||||
Self::Project => "project",
|
||||
Self::Reference => "reference",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Memory file metadata and content
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Scanned metadata for a single memory file (without the full body).
|
||||
/// Mirrors `MemoryHeader` in `memoryScan.ts`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryFileMeta {
|
||||
/// Filename relative to the memory directory (e.g. `user_role.md`).
|
||||
pub filename: String,
|
||||
/// Absolute path to the file.
|
||||
pub path: PathBuf,
|
||||
/// `name:` frontmatter field.
|
||||
pub name: Option<String>,
|
||||
/// `description:` frontmatter field (used for relevance scoring).
|
||||
pub description: Option<String>,
|
||||
/// `type:` frontmatter field.
|
||||
pub memory_type: Option<MemoryType>,
|
||||
/// File modification time in seconds since UNIX epoch.
|
||||
pub modified_secs: u64,
|
||||
}
|
||||
|
||||
/// A fully-loaded memory file including its body.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryFile {
|
||||
pub meta: MemoryFileMeta,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Directory scanning
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Maximum number of memory files kept after sorting.
|
||||
/// Matches `MAX_MEMORY_FILES` in `memoryScan.ts`.
|
||||
const MAX_MEMORY_FILES: usize = 200;
|
||||
|
||||
/// Number of lines scanned for frontmatter.
|
||||
/// Matches `FRONTMATTER_MAX_LINES` in `memoryScan.ts`.
|
||||
const FRONTMATTER_MAX_LINES: usize = 30;
|
||||
|
||||
/// Scan a memory directory, returning metadata for all `.md` files
|
||||
/// (excluding `MEMORY.md`), sorted newest-first, capped at `MAX_MEMORY_FILES`.
|
||||
///
|
||||
/// This is a synchronous scan used during system-prompt assembly.
|
||||
/// Mirrors `scanMemoryFiles` in `memoryScan.ts` (async version; this is the
|
||||
/// sync equivalent used at prompt-build time).
|
||||
pub fn scan_memory_dir(dir: &Path) -> Vec<MemoryFileMeta> {
|
||||
let mut files: Vec<MemoryFileMeta> = Vec::new();
|
||||
|
||||
if !dir.exists() {
|
||||
return files;
|
||||
}
|
||||
|
||||
// Walk recursively using `walkdir`-style manual recursion to stay
|
||||
// dependency-free (only std).
|
||||
collect_md_files(dir, dir, &mut files);
|
||||
|
||||
// Sort newest-first.
|
||||
files.sort_by(|a, b| b.modified_secs.cmp(&a.modified_secs));
|
||||
files.truncate(MAX_MEMORY_FILES);
|
||||
files
|
||||
}
|
||||
|
||||
/// Recursively collect `.md` files (excluding `MEMORY.md`) from `current_dir`.
|
||||
fn collect_md_files(base: &Path, current_dir: &Path, out: &mut Vec<MemoryFileMeta>) {
|
||||
let Ok(entries) = std::fs::read_dir(current_dir) else {
|
||||
return;
|
||||
};
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
collect_md_files(base, &path, out);
|
||||
} else if path.extension().map(|e| e == "md").unwrap_or(false) {
|
||||
let file_name = path.file_name().map(|n| n.to_string_lossy().into_owned()).unwrap_or_default();
|
||||
if file_name == "MEMORY.md" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let modified_secs = entry
|
||||
.metadata()
|
||||
.and_then(|m| m.modified())
|
||||
.map(|t| t.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
let (name, description, memory_type) =
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
parse_frontmatter_quick(&content)
|
||||
} else {
|
||||
(None, None, None)
|
||||
};
|
||||
|
||||
// Relative path from the memory dir root.
|
||||
let relative = path
|
||||
.strip_prefix(base)
|
||||
.map(|p| p.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|_| file_name.clone());
|
||||
|
||||
out.push(MemoryFileMeta {
|
||||
filename: relative,
|
||||
path,
|
||||
name,
|
||||
description,
|
||||
memory_type,
|
||||
modified_secs,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse YAML frontmatter from the first `FRONTMATTER_MAX_LINES` lines without
|
||||
/// a full YAML parser. Returns `(name, description, memory_type)`.
|
||||
///
|
||||
/// Mirrors `parseFrontmatter` usage in `memoryScan.ts`.
|
||||
pub fn parse_frontmatter_quick(
|
||||
content: &str,
|
||||
) -> (Option<String>, Option<String>, Option<MemoryType>) {
|
||||
let mut name = None;
|
||||
let mut description = None;
|
||||
let mut memory_type = None;
|
||||
|
||||
let lines: Vec<&str> = content.lines().take(FRONTMATTER_MAX_LINES).collect();
|
||||
|
||||
// Frontmatter must start with `---`
|
||||
if lines.first().map(|l| l.trim() != "---").unwrap_or(true) {
|
||||
return (name, description, memory_type);
|
||||
}
|
||||
|
||||
for line in &lines[1..] {
|
||||
if line.trim() == "---" {
|
||||
break;
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("name:") {
|
||||
name = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string());
|
||||
} else if let Some(rest) = line.strip_prefix("description:") {
|
||||
description = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string());
|
||||
} else if let Some(rest) = line.strip_prefix("type:") {
|
||||
memory_type = MemoryType::parse(rest.trim().trim_matches('"').trim_matches('\''));
|
||||
}
|
||||
}
|
||||
|
||||
(name, description, memory_type)
|
||||
}
|
||||
|
||||
/// Format memory headers as a text manifest: one entry per file with
|
||||
/// `[type] filename (iso-timestamp): description`.
|
||||
///
|
||||
/// Mirrors `formatMemoryManifest` in `memoryScan.ts`.
|
||||
pub fn format_memory_manifest(memories: &[MemoryFileMeta]) -> String {
|
||||
memories
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let tag = m
|
||||
.memory_type
|
||||
.as_ref()
|
||||
.map(|t| format!("[{}] ", t.as_str()))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Convert modified_secs to an ISO-8601-like timestamp.
|
||||
let ts = format_unix_secs_iso(m.modified_secs);
|
||||
|
||||
match &m.description {
|
||||
Some(desc) => format!("- {}{} ({}): {}", tag, m.filename, ts, desc),
|
||||
None => format!("- {}{} ({})", tag, m.filename, ts),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
/// Minimal ISO-8601 formatter for a Unix timestamp (no external deps).
|
||||
fn format_unix_secs_iso(secs: u64) -> String {
|
||||
// We use a very lightweight implementation to avoid pulling in chrono here
|
||||
// (chrono is already a workspace dep but we want this module to stay lean).
|
||||
// Accuracy to the day is sufficient for memory manifests.
|
||||
let days_since_epoch = secs / 86400;
|
||||
// Julian Day Number for 1970-01-01 is 2440588.
|
||||
let jdn = days_since_epoch as u32 + 2440588;
|
||||
let (y, m, d) = jdn_to_ymd(jdn);
|
||||
let hh = (secs % 86400) / 3600;
|
||||
let mm = (secs % 3600) / 60;
|
||||
let ss = secs % 60;
|
||||
format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", y, m, d, hh, mm, ss)
|
||||
}
|
||||
|
||||
/// Convert a Julian Day Number to (year, month, day).
|
||||
fn jdn_to_ymd(jdn: u32) -> (u32, u32, u32) {
|
||||
let a = jdn + 32044;
|
||||
let b = (4 * a + 3) / 146097;
|
||||
let c = a - (146097 * b) / 4;
|
||||
let d = (4 * c + 3) / 1461;
|
||||
let e = c - (1461 * d) / 4;
|
||||
let m = (5 * e + 2) / 153;
|
||||
let day = e - (153 * m + 2) / 5 + 1;
|
||||
let month = m + 3 - 12 * (m / 10);
|
||||
let year = 100 * b + d - 4800 + m / 10;
|
||||
(year, month, day)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Memory age / freshness
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Days elapsed since `modified_secs`. Floor-rounded; clamped to 0 for
|
||||
/// future mtimes (clock skew).
|
||||
///
|
||||
/// Mirrors `memoryAgeDays` in `memoryAge.ts`.
|
||||
pub fn memory_age_days(modified_secs: u64) -> u64 {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
(now.saturating_sub(modified_secs)) / 86400
|
||||
}
|
||||
|
||||
/// Human-readable age string. Models are poor at date arithmetic — a raw
|
||||
/// ISO timestamp does not trigger staleness reasoning the way "47 days ago" does.
|
||||
///
|
||||
/// Mirrors `memoryAge` in `memoryAge.ts`.
|
||||
pub fn memory_age(modified_secs: u64) -> String {
|
||||
let d = memory_age_days(modified_secs);
|
||||
match d {
|
||||
0 => "today".to_string(),
|
||||
1 => "yesterday".to_string(),
|
||||
n => format!("{} days ago", n),
|
||||
}
|
||||
}
|
||||
|
||||
/// Plain-text staleness caveat for memories > 1 day old.
|
||||
/// Returns an empty string for fresh memories (today / yesterday).
|
||||
///
|
||||
/// Mirrors `memoryFreshnessText` in `memoryAge.ts`.
|
||||
pub fn memory_freshness_text(modified_secs: u64) -> String {
|
||||
let d = memory_age_days(modified_secs);
|
||||
if d <= 1 {
|
||||
return String::new();
|
||||
}
|
||||
format!(
|
||||
"This memory is {} days old. \
|
||||
Memories are point-in-time observations, not live state — \
|
||||
claims about code behavior or file:line citations may be outdated. \
|
||||
Verify against current code before asserting as fact.",
|
||||
d
|
||||
)
|
||||
}
|
||||
|
||||
/// Per-memory staleness note wrapped in `<system-reminder>` tags.
|
||||
/// Returns an empty string for memories ≤ 1 day old.
|
||||
///
|
||||
/// Mirrors `memoryFreshnessNote` in `memoryAge.ts`.
|
||||
pub fn memory_freshness_note(modified_secs: u64) -> String {
|
||||
let text = memory_freshness_text(modified_secs);
|
||||
if text.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
format!("<system-reminder>{}</system-reminder>\n", text)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Path resolution
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Entrypoint filename within the memory directory.
|
||||
pub const MEMORY_ENTRYPOINT: &str = "MEMORY.md";
|
||||
|
||||
/// Maximum number of lines loaded from `MEMORY.md`.
|
||||
/// Matches `MAX_ENTRYPOINT_LINES` in `memdir.ts`.
|
||||
pub const MAX_ENTRYPOINT_LINES: usize = 200;
|
||||
|
||||
/// Maximum bytes loaded from `MEMORY.md`.
|
||||
/// Matches `MAX_ENTRYPOINT_BYTES` in `memdir.ts`.
|
||||
pub const MAX_ENTRYPOINT_BYTES: usize = 25_000;
|
||||
|
||||
/// Compute the auto-memory directory path for a project root.
|
||||
///
|
||||
/// Resolution order (mirrors `getAutoMemPath` in `paths.ts`):
|
||||
/// 1. `CLAUDE_COWORK_MEMORY_PATH_OVERRIDE` env var (full-path override).
|
||||
/// 2. `<CLAUDE_CODE_REMOTE_MEMORY_DIR>/projects/<sanitized-root>/memory/`
|
||||
/// when `CLAUDE_CODE_REMOTE_MEMORY_DIR` is set.
|
||||
/// 3. `~/.claude/projects/<sanitized-root>/memory/` (default).
|
||||
pub fn auto_memory_path(project_root: &Path) -> PathBuf {
|
||||
// 1. Cowork full-path override.
|
||||
if let Ok(override_path) = std::env::var("CLAUDE_COWORK_MEMORY_PATH_OVERRIDE") {
|
||||
if !override_path.is_empty() {
|
||||
return PathBuf::from(override_path);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Determine the memory base directory.
|
||||
let memory_base = std::env::var("CLAUDE_CODE_REMOTE_MEMORY_DIR")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|_| {
|
||||
dirs::home_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join(".claude")
|
||||
});
|
||||
|
||||
// 3. Sanitize the project root into a safe directory name.
|
||||
let sanitized = sanitize_path_component(&project_root.to_string_lossy());
|
||||
|
||||
memory_base.join("projects").join(sanitized).join("memory")
|
||||
}
|
||||
|
||||
/// Sanitize an arbitrary string into a directory-name-safe component.
|
||||
/// Matches `sanitizePath` used inside `getAutoMemPath` in `paths.ts`.
|
||||
pub fn sanitize_path_component(s: &str) -> String {
|
||||
let sanitized: String = s
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_alphanumeric() || c == '-' || c == '_' || c == '.' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
sanitized.trim_matches('_').to_string()
|
||||
}
|
||||
|
||||
/// Whether the auto-memory system is enabled for this session.
|
||||
///
|
||||
/// Priority chain (mirrors `isAutoMemoryEnabled` in `paths.ts`):
|
||||
/// 1. `CLAUDE_CODE_DISABLE_AUTO_MEMORY` — truthy → OFF, falsy (but defined) → ON.
|
||||
/// 2. `CLAUDE_CODE_SIMPLE` (--bare) → OFF.
|
||||
/// 3. Remote mode without `CLAUDE_CODE_REMOTE_MEMORY_DIR` → OFF.
|
||||
/// 4. `settings_enabled` parameter (from settings.json `autoMemoryEnabled` field).
|
||||
/// 5. Default: enabled.
|
||||
pub fn is_auto_memory_enabled(settings_enabled: Option<bool>) -> bool {
|
||||
if let Ok(val) = std::env::var("CLAUDE_CODE_DISABLE_AUTO_MEMORY") {
|
||||
// Truthy values (non-empty, non-"0", non-"false") disable memory.
|
||||
match val.to_lowercase().as_str() {
|
||||
"" | "0" | "false" | "no" | "off" => return true, // defined-falsy → ON
|
||||
_ => return false, // truthy → OFF
|
||||
}
|
||||
}
|
||||
|
||||
if std::env::var("CLAUDE_CODE_SIMPLE").is_ok() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if std::env::var("CLAUDE_CODE_REMOTE").is_ok()
|
||||
&& std::env::var("CLAUDE_CODE_REMOTE_MEMORY_DIR").is_err()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
settings_enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Index loading and truncation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Result of loading and (optionally) truncating the `MEMORY.md` entrypoint.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EntrypointTruncation {
|
||||
pub content: String,
|
||||
pub line_count: usize,
|
||||
pub byte_count: usize,
|
||||
pub was_line_truncated: bool,
|
||||
pub was_byte_truncated: bool,
|
||||
}
|
||||
|
||||
/// Truncate `MEMORY.md` content to `MAX_ENTRYPOINT_LINES` lines and
|
||||
/// `MAX_ENTRYPOINT_BYTES` bytes, appending a warning when either cap fires.
|
||||
///
|
||||
/// Mirrors `truncateEntrypointContent` in `memdir.ts`.
|
||||
pub fn truncate_entrypoint_content(raw: &str) -> EntrypointTruncation {
|
||||
let trimmed = raw.trim();
|
||||
let content_lines: Vec<&str> = trimmed.lines().collect();
|
||||
let line_count = content_lines.len();
|
||||
let byte_count = trimmed.len();
|
||||
|
||||
let was_line_truncated = line_count > MAX_ENTRYPOINT_LINES;
|
||||
let was_byte_truncated = byte_count > MAX_ENTRYPOINT_BYTES;
|
||||
|
||||
if !was_line_truncated && !was_byte_truncated {
|
||||
return EntrypointTruncation {
|
||||
content: trimmed.to_string(),
|
||||
line_count,
|
||||
byte_count,
|
||||
was_line_truncated: false,
|
||||
was_byte_truncated: false,
|
||||
};
|
||||
}
|
||||
|
||||
let mut truncated = if was_line_truncated {
|
||||
content_lines[..MAX_ENTRYPOINT_LINES].join("\n")
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
};
|
||||
|
||||
if truncated.len() > MAX_ENTRYPOINT_BYTES {
|
||||
let cut_at = truncated[..MAX_ENTRYPOINT_BYTES]
|
||||
.rfind('\n')
|
||||
.unwrap_or(MAX_ENTRYPOINT_BYTES);
|
||||
truncated.truncate(cut_at);
|
||||
}
|
||||
|
||||
let reason = match (was_line_truncated, was_byte_truncated) {
|
||||
(true, false) => format!("{} lines (limit: {})", line_count, MAX_ENTRYPOINT_LINES),
|
||||
(false, true) => format!(
|
||||
"{} bytes (limit: {}) — index entries are too long",
|
||||
byte_count, MAX_ENTRYPOINT_BYTES
|
||||
),
|
||||
_ => format!(
|
||||
"{} lines and {} bytes",
|
||||
line_count, byte_count
|
||||
),
|
||||
};
|
||||
|
||||
truncated.push_str(&format!(
|
||||
"\n\n> WARNING: {} is {}. Only part of it was loaded. \
|
||||
Keep index entries to one line under ~200 chars; move detail into topic files.",
|
||||
MEMORY_ENTRYPOINT, reason
|
||||
));
|
||||
|
||||
EntrypointTruncation {
|
||||
content: truncated,
|
||||
line_count,
|
||||
byte_count,
|
||||
was_line_truncated,
|
||||
was_byte_truncated,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load and truncate the `MEMORY.md` index from `memory_dir`.
|
||||
/// Returns `None` when the file does not exist or is empty.
|
||||
///
|
||||
/// Mirrors the entrypoint-reading path in `buildMemoryPrompt` / `loadMemoryPrompt`.
|
||||
pub fn load_memory_index(memory_dir: &Path) -> Option<EntrypointTruncation> {
|
||||
let index_path = memory_dir.join(MEMORY_ENTRYPOINT);
|
||||
if !index_path.exists() {
|
||||
return None;
|
||||
}
|
||||
let raw = std::fs::read_to_string(&index_path).ok()?;
|
||||
if raw.trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(truncate_entrypoint_content(&raw))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// System-prompt memory content builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build the memory content string to inject into the system prompt's
|
||||
/// `<memory>` block.
|
||||
///
|
||||
/// Always includes the `MEMORY.md` index when it exists.
|
||||
/// Called during `build_system_prompt` → `SystemPromptOptions::memory_content`.
|
||||
pub fn build_memory_prompt_content(memory_dir: &Path) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
|
||||
if let Some(index) = load_memory_index(memory_dir) {
|
||||
parts.push(format!("## Memory Index (MEMORY.md)\n{}", index.content));
|
||||
}
|
||||
|
||||
parts.join("\n\n")
|
||||
}
|
||||
|
||||
/// Ensure the memory directory exists, creating it (and any parents) if needed.
|
||||
/// Errors are silently swallowed (the Write tool will surface them if needed).
|
||||
///
|
||||
/// Mirrors `ensureMemoryDirExists` in `memdir.ts`.
|
||||
pub fn ensure_memory_dir_exists(memory_dir: &Path) {
|
||||
if let Err(e) = std::fs::create_dir_all(memory_dir) {
|
||||
// Log at debug level so --debug shows why, but don't abort.
|
||||
tracing::debug!(
|
||||
dir = %memory_dir.display(),
|
||||
error = %e,
|
||||
"ensureMemoryDirExists failed"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Simple relevance search (no LLM side-query)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Find and load the most relevant memory files for a query using a
|
||||
/// lightweight TF-IDF-style keyword score.
|
||||
///
|
||||
/// The full Sonnet side-query (`findRelevantMemories` in TypeScript) lives
|
||||
/// in `cc-query`; this function provides a cheaper fallback for contexts
|
||||
/// where an API call is not available.
|
||||
pub fn find_relevant_memories_simple(
|
||||
memory_dir: &Path,
|
||||
query: &str,
|
||||
max_files: usize,
|
||||
) -> Vec<MemoryFile> {
|
||||
let metas = scan_memory_dir(memory_dir);
|
||||
let query_lower = query.to_lowercase();
|
||||
let query_words: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
|
||||
if query_words.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut scored: Vec<(f32, MemoryFileMeta)> = metas
|
||||
.into_iter()
|
||||
.filter_map(|meta| {
|
||||
let desc = meta.description.as_deref().unwrap_or("").to_lowercase();
|
||||
let name = meta.name.as_deref().unwrap_or("").to_lowercase();
|
||||
let filename = meta.filename.to_lowercase();
|
||||
|
||||
let score: f32 = query_words
|
||||
.iter()
|
||||
.map(|w| {
|
||||
let in_name = if name.contains(*w) { 2.0_f32 } else { 0.0 };
|
||||
let in_desc = if desc.contains(*w) { 1.0_f32 } else { 0.0 };
|
||||
let in_file = if filename.contains(*w) { 0.5_f32 } else { 0.0 };
|
||||
in_name + in_desc + in_file
|
||||
})
|
||||
.sum();
|
||||
|
||||
if score > 0.0 { Some((score, meta)) } else { None }
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort highest score first.
|
||||
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
scored
|
||||
.into_iter()
|
||||
.take(max_files)
|
||||
.filter_map(|(_, meta)| {
|
||||
let content = std::fs::read_to_string(&meta.path).ok()?;
|
||||
Some(MemoryFile { meta, content })
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Team memory helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return the team-memory sub-directory path.
|
||||
/// Mirrors `getTeamMemPath` in `teamMemPaths.ts`.
|
||||
pub fn team_memory_path(auto_memory_dir: &Path) -> PathBuf {
|
||||
auto_memory_dir.join("team")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write as IoWrite;
|
||||
|
||||
// Helpers ----------------------------------------------------------------
|
||||
|
||||
fn make_temp_dir() -> tempfile::TempDir {
|
||||
tempfile::tempdir().expect("tempdir")
|
||||
}
|
||||
|
||||
fn write_file(dir: &Path, name: &str, content: &str) {
|
||||
let path = dir.join(name);
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent).unwrap();
|
||||
}
|
||||
let mut f = std::fs::File::create(&path).unwrap();
|
||||
f.write_all(content.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
// ---- parse_frontmatter_quick -------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_full() {
|
||||
let content = "---\nname: My Memory\ndescription: A test description\ntype: feedback\n---\n\nBody text.";
|
||||
let (name, desc, mt) = parse_frontmatter_quick(content);
|
||||
assert_eq!(name.as_deref(), Some("My Memory"));
|
||||
assert_eq!(desc.as_deref(), Some("A test description"));
|
||||
assert_eq!(mt, Some(MemoryType::Feedback));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_no_frontmatter() {
|
||||
let content = "Just plain text.";
|
||||
let (name, desc, mt) = parse_frontmatter_quick(content);
|
||||
assert!(name.is_none());
|
||||
assert!(desc.is_none());
|
||||
assert!(mt.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_quoted_values() {
|
||||
let content = "---\nname: \"Quoted Name\"\ndescription: 'Single quoted'\ntype: user\n---";
|
||||
let (name, desc, mt) = parse_frontmatter_quick(content);
|
||||
assert_eq!(name.as_deref(), Some("Quoted Name"));
|
||||
assert_eq!(desc.as_deref(), Some("Single quoted"));
|
||||
assert_eq!(mt, Some(MemoryType::User));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_unknown_type() {
|
||||
let content = "---\ntype: unknown_type\n---";
|
||||
let (_, _, mt) = parse_frontmatter_quick(content);
|
||||
assert!(mt.is_none());
|
||||
}
|
||||
|
||||
// ---- memory_age_days ---------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_memory_age_today() {
|
||||
let now_secs = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
assert_eq!(memory_age_days(now_secs), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_age_one_day_ago() {
|
||||
let yesterday = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
.saturating_sub(86_400);
|
||||
assert_eq!(memory_age_days(yesterday), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_age_future_clamps_to_zero() {
|
||||
let far_future = u64::MAX;
|
||||
assert_eq!(memory_age_days(far_future), 0);
|
||||
}
|
||||
|
||||
// ---- memory_freshness_text ---------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_freshness_text_fresh() {
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
|
||||
assert!(memory_freshness_text(now).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_freshness_text_stale() {
|
||||
let old = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
.saturating_sub(10 * 86_400); // 10 days ago
|
||||
let text = memory_freshness_text(old);
|
||||
assert!(text.contains("10 days old"));
|
||||
assert!(text.contains("point-in-time"));
|
||||
}
|
||||
|
||||
// ---- memory_freshness_note ---------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_freshness_note_fresh_is_empty() {
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
|
||||
assert!(memory_freshness_note(now).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_freshness_note_stale_has_tags() {
|
||||
let old = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
.saturating_sub(5 * 86_400);
|
||||
let note = memory_freshness_note(old);
|
||||
assert!(note.contains("<system-reminder>"));
|
||||
assert!(note.contains("</system-reminder>"));
|
||||
}
|
||||
|
||||
// ---- truncate_entrypoint_content ---------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_truncate_no_truncation_needed() {
|
||||
let content = "line1\nline2\nline3";
|
||||
let result = truncate_entrypoint_content(content);
|
||||
assert!(!result.was_line_truncated);
|
||||
assert!(!result.was_byte_truncated);
|
||||
assert_eq!(result.content, content);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_line_limit() {
|
||||
let content = (0..=MAX_ENTRYPOINT_LINES)
|
||||
.map(|i| format!("line {}", i))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let result = truncate_entrypoint_content(&content);
|
||||
assert!(result.was_line_truncated);
|
||||
assert!(result.content.contains("WARNING"));
|
||||
}
|
||||
|
||||
// ---- sanitize_path_component -------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_path_component() {
|
||||
assert_eq!(sanitize_path_component("/home/user/project"), "_home_user_project");
|
||||
assert_eq!(sanitize_path_component("normal-name_123"), "normal-name_123");
|
||||
assert_eq!(sanitize_path_component("C:\\Users\\foo"), "C__Users_foo");
|
||||
}
|
||||
|
||||
// ---- load_memory_index -------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_load_memory_index_nonexistent() {
|
||||
let dir = make_temp_dir();
|
||||
assert!(load_memory_index(dir.path()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_memory_index_empty() {
|
||||
let dir = make_temp_dir();
|
||||
write_file(dir.path(), "MEMORY.md", " ");
|
||||
assert!(load_memory_index(dir.path()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_memory_index_with_content() {
|
||||
let dir = make_temp_dir();
|
||||
write_file(dir.path(), "MEMORY.md", "- [test.md](test.md) — something");
|
||||
let result = load_memory_index(dir.path()).unwrap();
|
||||
assert!(result.content.contains("test.md"));
|
||||
}
|
||||
|
||||
// ---- scan_memory_dir ---------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_scan_excludes_memory_md() {
|
||||
let dir = make_temp_dir();
|
||||
write_file(dir.path(), "MEMORY.md", "# index");
|
||||
write_file(dir.path(), "user_role.md", "---\nname: Role\n---");
|
||||
let metas = scan_memory_dir(dir.path());
|
||||
assert_eq!(metas.len(), 1);
|
||||
assert_eq!(metas[0].filename, "user_role.md");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scan_empty_dir() {
|
||||
let dir = make_temp_dir();
|
||||
assert!(scan_memory_dir(dir.path()).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scan_nonexistent_dir() {
|
||||
let path = PathBuf::from("/tmp/nonexistent_memory_dir_cc_rust_test_xyz");
|
||||
assert!(scan_memory_dir(&path).is_empty());
|
||||
}
|
||||
|
||||
// ---- format_memory_manifest --------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_format_memory_manifest_with_description() {
|
||||
let meta = MemoryFileMeta {
|
||||
filename: "user_role.md".to_string(),
|
||||
path: PathBuf::from("user_role.md"),
|
||||
name: Some("User Role".to_string()),
|
||||
description: Some("The user is a data scientist".to_string()),
|
||||
memory_type: Some(MemoryType::User),
|
||||
modified_secs: 0,
|
||||
};
|
||||
let manifest = format_memory_manifest(&[meta]);
|
||||
assert!(manifest.contains("[user]"));
|
||||
assert!(manifest.contains("user_role.md"));
|
||||
assert!(manifest.contains("data scientist"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_memory_manifest_no_description() {
|
||||
let meta = MemoryFileMeta {
|
||||
filename: "ref.md".to_string(),
|
||||
path: PathBuf::from("ref.md"),
|
||||
name: None,
|
||||
description: None,
|
||||
memory_type: None,
|
||||
modified_secs: 0,
|
||||
};
|
||||
let manifest = format_memory_manifest(&[meta]);
|
||||
assert!(manifest.contains("ref.md"));
|
||||
// No description separator colon
|
||||
assert!(!manifest.contains("ref.md ("));
|
||||
}
|
||||
|
||||
// ---- MemoryType --------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_memory_type_roundtrip() {
|
||||
for (s, expected) in [
|
||||
("user", MemoryType::User),
|
||||
("feedback", MemoryType::Feedback),
|
||||
("project", MemoryType::Project),
|
||||
("reference", MemoryType::Reference),
|
||||
] {
|
||||
let parsed = MemoryType::parse(s).unwrap();
|
||||
assert_eq!(parsed, expected);
|
||||
assert_eq!(parsed.as_str(), s);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_type_unknown_returns_none() {
|
||||
assert!(MemoryType::parse("bogus").is_none());
|
||||
}
|
||||
|
||||
// ---- is_auto_memory_enabled -------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_auto_memory_enabled_default() {
|
||||
// No env vars set for this test, settings None → should be enabled.
|
||||
// We can't guarantee the test environment is clean, so just check it
|
||||
// returns a bool without panicking.
|
||||
let _ = is_auto_memory_enabled(None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auto_memory_disabled_by_setting() {
|
||||
// If settings explicitly disable it and no env override, returns false.
|
||||
// We only test the settings-path without touching process env.
|
||||
// Simulate: env vars not set, settings says false.
|
||||
// We can't unset env vars reliably in tests, so just ensure the
|
||||
// function handles Some(false) without panicking.
|
||||
// (The full env-var paths are integration-tested separately.)
|
||||
let _ = is_auto_memory_enabled(Some(false));
|
||||
}
|
||||
}
|
||||
474
src-rust/crates/core/src/migrations.rs
Normal file
474
src-rust/crates/core/src/migrations.rs
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
//! Settings migration framework
|
||||
//! Runs on startup to upgrade settings.json from older versions.
|
||||
//!
|
||||
//! Migrations are derived from the TypeScript originals:
|
||||
//! - src/migrations/migrateFennecToOpus.ts
|
||||
//! - src/migrations/migrateLegacyOpusToCurrent.ts
|
||||
//! - src/migrations/migrateSonnet45ToSonnet46.ts
|
||||
//! - src/migrations/migrateAutoUpdatesToSettings.ts
|
||||
//! - (and several others without separate TS source files)
|
||||
//!
|
||||
//! Each migration is idempotent: it only touches fields it recognises and
|
||||
//! only writes when it actually changes something.
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
/// A single migration function.
|
||||
/// Returns `true` if the settings object was modified.
|
||||
pub type MigrationFn = fn(&mut Value) -> bool;
|
||||
|
||||
/// All migrations in the order they must be applied.
|
||||
pub const MIGRATIONS: &[(&str, MigrationFn)] = &[
|
||||
("migrate_fennec_to_opus", migrate_fennec_to_opus),
|
||||
("migrate_legacy_opus_to_current", migrate_legacy_opus_to_current),
|
||||
("migrate_opus_to_opus_1m", migrate_opus_to_opus_1m),
|
||||
("migrate_sonnet_1m_to_sonnet_45", migrate_sonnet_1m_to_sonnet_45),
|
||||
("migrate_sonnet_45_to_sonnet_46", migrate_sonnet_45_to_sonnet_46),
|
||||
(
|
||||
"migrate_bypass_permissions_to_settings",
|
||||
migrate_bypass_permissions_to_settings,
|
||||
),
|
||||
(
|
||||
"migrate_repl_bridge_to_remote_control",
|
||||
migrate_repl_bridge_to_remote_control,
|
||||
),
|
||||
("migrate_enable_all_mcp_servers", migrate_enable_all_mcp_servers),
|
||||
("migrate_auto_updates", migrate_auto_updates),
|
||||
("reset_auto_mode_opt_in", reset_auto_mode_opt_in),
|
||||
("reset_pro_to_opus_default", reset_pro_to_opus_default),
|
||||
];
|
||||
|
||||
/// Apply every pending migration to a settings `Value` (must be a JSON object).
|
||||
/// Returns `true` when at least one migration changed the settings.
|
||||
pub fn run_migrations(settings: &mut Value) -> bool {
|
||||
let mut changed = false;
|
||||
for (name, migration) in MIGRATIONS {
|
||||
if migration(settings) {
|
||||
tracing::info!("Applied settings migration: {}", name);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
changed
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model-name migrations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fennec was an internal alias; map to the current Opus line.
|
||||
/// Source: migrateFennecToOpus.ts
|
||||
fn migrate_fennec_to_opus(settings: &mut Value) -> bool {
|
||||
// fennec-latest[1m] → opus[1m], fennec-latest → opus
|
||||
// fennec-fast-latest / opus-4-5-fast → opus[1m] (fast-mode alias)
|
||||
let model = match settings.get("model").and_then(|v: &Value| v.as_str()) {
|
||||
Some(m) => m.to_string(),
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if model.starts_with("fennec-latest[1m]") {
|
||||
settings["model"] = Value::String("opus[1m]".to_string());
|
||||
return true;
|
||||
}
|
||||
if model.starts_with("fennec-latest") {
|
||||
settings["model"] = Value::String("opus".to_string());
|
||||
return true;
|
||||
}
|
||||
if model.starts_with("fennec-fast-latest") || model.starts_with("opus-4-5-fast") {
|
||||
settings["model"] = Value::String("opus[1m]".to_string());
|
||||
settings["fastMode"] = Value::Bool(true);
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Migrate explicit Opus 4.0/4.1 strings to the `opus` alias.
|
||||
/// Source: migrateLegacyOpusToCurrent.ts
|
||||
fn migrate_legacy_opus_to_current(settings: &mut Value) -> bool {
|
||||
const LEGACY_OPUS: &[&str] = &[
|
||||
"claude-opus-4-20250514",
|
||||
"claude-opus-4-1-20250805",
|
||||
"claude-opus-4-0",
|
||||
"claude-opus-4-1",
|
||||
];
|
||||
let model = match settings.get("model").and_then(|v: &Value| v.as_str()) {
|
||||
Some(m) => m.to_string(),
|
||||
None => return false,
|
||||
};
|
||||
if LEGACY_OPUS.contains(&model.as_str()) {
|
||||
settings["model"] = Value::String("opus".to_string());
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Rename the old explicit `claude-opus-4-0` model string (pre-alias era).
|
||||
fn migrate_opus_to_opus_1m(settings: &mut Value) -> bool {
|
||||
rename_model(settings, "claude-opus-4-0", "claude-opus-4-5-20251001")
|
||||
}
|
||||
|
||||
/// Migrate the old Sonnet 1m string to the Sonnet 4.5 release ID.
|
||||
fn migrate_sonnet_1m_to_sonnet_45(settings: &mut Value) -> bool {
|
||||
rename_model(
|
||||
settings,
|
||||
"claude-sonnet-4-0-1m",
|
||||
"claude-sonnet-4-5-20251015",
|
||||
)
|
||||
}
|
||||
|
||||
/// Migrate Sonnet 4.5 explicit IDs to `sonnet` (which resolves to 4.6).
|
||||
/// Source: migrateSonnet45ToSonnet46.ts
|
||||
fn migrate_sonnet_45_to_sonnet_46(settings: &mut Value) -> bool {
|
||||
const SONNET_45_IDS: &[&str] = &[
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5-20250929[1m]",
|
||||
"sonnet-4-5-20250929",
|
||||
"sonnet-4-5-20250929[1m]",
|
||||
// Also handle the model strings used in the older Rust migrations table:
|
||||
"claude-sonnet-4-5-20251015",
|
||||
"claude-sonnet-4-5",
|
||||
];
|
||||
|
||||
let model = match settings.get("model").and_then(|v: &Value| v.as_str()) {
|
||||
Some(m) => m.to_string(),
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if SONNET_45_IDS.contains(&model.as_str()) {
|
||||
let has_1m = model.ends_with("[1m]");
|
||||
let new_model = if has_1m { "sonnet[1m]" } else { "sonnet" };
|
||||
settings["model"] = Value::String(new_model.to_string());
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Rename `from` to `to` in the `model`, `defaultModel`, and `mainLoopModel`
|
||||
/// fields. Returns `true` if any field was changed.
|
||||
fn rename_model(settings: &mut Value, from: &str, to: &str) -> bool {
|
||||
let mut changed = false;
|
||||
for key in &["model", "defaultModel", "mainLoopModel"] {
|
||||
if let Some(val) = settings.get_mut(*key) {
|
||||
if val.as_str() == Some(from) {
|
||||
*val = Value::String(to.to_string());
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
changed
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config-structure migrations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Move `bypassPermissionsAccepted` boolean into `permissionMode`.
|
||||
fn migrate_bypass_permissions_to_settings(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("bypassPermissionsAccepted").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("permissionMode").is_none() && old.as_bool().unwrap_or(false) {
|
||||
settings["permissionMode"] = Value::String("bypass".to_string());
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("bypassPermissionsAccepted");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Rename `replBridgeEnabled` → `remoteControlAtStartup`.
|
||||
fn migrate_repl_bridge_to_remote_control(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("replBridgeEnabled").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("remoteControlAtStartup").is_none() {
|
||||
settings["remoteControlAtStartup"] = old;
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("replBridgeEnabled");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Rename `enableAllProjectMcpServers` → `mcpAutoApprove`.
|
||||
fn migrate_enable_all_mcp_servers(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("enableAllProjectMcpServers").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("mcpAutoApprove").is_none() {
|
||||
settings["mcpAutoApprove"] = old;
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("enableAllProjectMcpServers");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Migrate `autoUpdatesEnabled` → `autoUpdates`.
|
||||
/// Source: migrateAutoUpdatesToSettings.ts
|
||||
/// The TS version also writes an env-var to settings.json; here we keep the
|
||||
/// simpler structural rename and leave env-var injection to the caller.
|
||||
fn migrate_auto_updates(settings: &mut Value) -> bool {
|
||||
let old = match settings.get("autoUpdatesEnabled").cloned() {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if settings.get("autoUpdates").is_none() {
|
||||
settings["autoUpdates"] = old;
|
||||
}
|
||||
|
||||
if let Some(obj) = settings.as_object_mut() {
|
||||
obj.remove("autoUpdatesEnabled");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Clear an old sentinel value for the auto-mode opt-in flag.
|
||||
fn reset_auto_mode_opt_in(settings: &mut Value) -> bool {
|
||||
if let Some(val) = settings.get("autoModeOptIn") {
|
||||
if val.as_str() == Some("default_offer_2024") {
|
||||
settings["autoModeOptIn"] = Value::Null;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Reset users who were auto-defaulted to Opus back to Sonnet 4.6.
|
||||
/// Only resets when `modelSetByUser` is not explicitly `true`.
|
||||
fn reset_pro_to_opus_default(settings: &mut Value) -> bool {
|
||||
if let Some(val) = settings.get("model") {
|
||||
if val.as_str() == Some("claude-opus-4-5-20251001") {
|
||||
let set_by_user = settings
|
||||
.get("modelSetByUser")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
if !set_by_user {
|
||||
settings["model"] = Value::String("claude-sonnet-4-6".to_string());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn settings(model: &str) -> Value {
|
||||
json!({ "model": model })
|
||||
}
|
||||
|
||||
// ---- rename_model -------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn rename_model_changes_matching_field() {
|
||||
let mut s = settings("old-model");
|
||||
assert!(rename_model(&mut s, "old-model", "new-model"));
|
||||
assert_eq!(s["model"].as_str(), Some("new-model"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rename_model_no_change_when_different() {
|
||||
let mut s = settings("something-else");
|
||||
assert!(!rename_model(&mut s, "old-model", "new-model"));
|
||||
assert_eq!(s["model"].as_str(), Some("something-else"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rename_model_covers_all_keys() {
|
||||
let mut s = json!({
|
||||
"model": "claude-foo",
|
||||
"defaultModel": "claude-foo",
|
||||
"mainLoopModel": "claude-foo",
|
||||
});
|
||||
assert!(rename_model(&mut s, "claude-foo", "claude-bar"));
|
||||
assert_eq!(s["model"].as_str(), Some("claude-bar"));
|
||||
assert_eq!(s["defaultModel"].as_str(), Some("claude-bar"));
|
||||
assert_eq!(s["mainLoopModel"].as_str(), Some("claude-bar"));
|
||||
}
|
||||
|
||||
// ---- migrate_fennec_to_opus ---------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn fennec_latest_1m_maps_to_opus_1m() {
|
||||
let mut s = settings("fennec-latest[1m]");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus[1m]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fennec_latest_maps_to_opus() {
|
||||
let mut s = settings("fennec-latest");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fennec_fast_maps_to_opus_1m_with_fast_mode() {
|
||||
let mut s = settings("fennec-fast-latest");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus[1m]"));
|
||||
assert_eq!(s["fastMode"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opus_4_5_fast_maps_to_opus_1m_with_fast_mode() {
|
||||
let mut s = settings("opus-4-5-fast");
|
||||
assert!(migrate_fennec_to_opus(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus[1m]"));
|
||||
assert_eq!(s["fastMode"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fennec_no_match_returns_false() {
|
||||
let mut s = settings("claude-sonnet-4-6");
|
||||
assert!(!migrate_fennec_to_opus(&mut s));
|
||||
}
|
||||
|
||||
// ---- migrate_legacy_opus_to_current ------------------------------------
|
||||
|
||||
#[test]
|
||||
fn legacy_opus_4_0_maps_to_opus() {
|
||||
let mut s = settings("claude-opus-4-0");
|
||||
assert!(migrate_legacy_opus_to_current(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn legacy_opus_4_1_maps_to_opus() {
|
||||
let mut s = settings("claude-opus-4-1-20250805");
|
||||
assert!(migrate_legacy_opus_to_current(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("opus"));
|
||||
}
|
||||
|
||||
// ---- migrate_sonnet_45_to_sonnet_46 ------------------------------------
|
||||
|
||||
#[test]
|
||||
fn sonnet_45_explicit_id_maps_to_sonnet() {
|
||||
let mut s = settings("claude-sonnet-4-5-20250929");
|
||||
assert!(migrate_sonnet_45_to_sonnet_46(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("sonnet"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sonnet_45_1m_maps_to_sonnet_1m() {
|
||||
let mut s = settings("claude-sonnet-4-5-20250929[1m]");
|
||||
assert!(migrate_sonnet_45_to_sonnet_46(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("sonnet[1m]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sonnet_46_is_untouched() {
|
||||
let mut s = settings("claude-sonnet-4-6");
|
||||
assert!(!migrate_sonnet_45_to_sonnet_46(&mut s));
|
||||
}
|
||||
|
||||
// ---- struct migrations -------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn bypass_permissions_migrates_and_removes_old_key() {
|
||||
let mut s = json!({ "bypassPermissionsAccepted": true });
|
||||
assert!(migrate_bypass_permissions_to_settings(&mut s));
|
||||
assert!(s.get("bypassPermissionsAccepted").is_none());
|
||||
assert_eq!(s["permissionMode"].as_str(), Some("bypass"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bypass_permissions_false_does_not_set_mode() {
|
||||
let mut s = json!({ "bypassPermissionsAccepted": false });
|
||||
assert!(migrate_bypass_permissions_to_settings(&mut s));
|
||||
assert!(s.get("permissionMode").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn repl_bridge_renames_field() {
|
||||
let mut s = json!({ "replBridgeEnabled": true });
|
||||
assert!(migrate_repl_bridge_to_remote_control(&mut s));
|
||||
assert!(s.get("replBridgeEnabled").is_none());
|
||||
assert_eq!(s["remoteControlAtStartup"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enable_all_mcp_renames_field() {
|
||||
let mut s = json!({ "enableAllProjectMcpServers": true });
|
||||
assert!(migrate_enable_all_mcp_servers(&mut s));
|
||||
assert!(s.get("enableAllProjectMcpServers").is_none());
|
||||
assert_eq!(s["mcpAutoApprove"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_updates_renames_field() {
|
||||
let mut s = json!({ "autoUpdatesEnabled": false });
|
||||
assert!(migrate_auto_updates(&mut s));
|
||||
assert!(s.get("autoUpdatesEnabled").is_none());
|
||||
assert_eq!(s["autoUpdates"].as_bool(), Some(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_auto_mode_clears_sentinel() {
|
||||
let mut s = json!({ "autoModeOptIn": "default_offer_2024" });
|
||||
assert!(reset_auto_mode_opt_in(&mut s));
|
||||
assert!(s["autoModeOptIn"].is_null());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_auto_mode_leaves_other_values() {
|
||||
let mut s = json!({ "autoModeOptIn": "user_opted_in" });
|
||||
assert!(!reset_auto_mode_opt_in(&mut s));
|
||||
assert_eq!(s["autoModeOptIn"].as_str(), Some("user_opted_in"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_pro_opus_default_resets_when_not_user_set() {
|
||||
let mut s = json!({ "model": "claude-opus-4-5-20251001" });
|
||||
assert!(reset_pro_to_opus_default(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("claude-sonnet-4-6"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_pro_opus_default_preserves_when_user_set() {
|
||||
let mut s = json!({ "model": "claude-opus-4-5-20251001", "modelSetByUser": true });
|
||||
assert!(!reset_pro_to_opus_default(&mut s));
|
||||
assert_eq!(s["model"].as_str(), Some("claude-opus-4-5-20251001"));
|
||||
}
|
||||
|
||||
// ---- run_migrations integration ----------------------------------------
|
||||
|
||||
#[test]
|
||||
fn run_migrations_applies_chain() {
|
||||
// A Sonnet 4.5 model should end up as "sonnet" after the full chain.
|
||||
let mut s = json!({ "model": "claude-sonnet-4-5-20250929" });
|
||||
let changed = run_migrations(&mut s);
|
||||
assert!(changed);
|
||||
assert_eq!(s["model"].as_str(), Some("sonnet"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_migrations_returns_false_when_nothing_changes() {
|
||||
let mut s = json!({ "model": "claude-sonnet-4-6", "someOtherKey": 42 });
|
||||
assert!(!run_migrations(&mut s));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_migrations_handles_empty_object() {
|
||||
let mut s = json!({});
|
||||
// No model fields, no sentinel values → nothing to do.
|
||||
assert!(!run_migrations(&mut s));
|
||||
}
|
||||
}
|
||||
364
src-rust/crates/core/src/oauth_config.rs
Normal file
364
src-rust/crates/core/src/oauth_config.rs
Normal file
|
|
@ -0,0 +1,364 @@
|
|||
//! OAuth configuration for multiple environments.
|
||||
//!
|
||||
//! This module mirrors the TypeScript `src/constants/oauth.ts` and
|
||||
//! `src/services/oauth/crypto.ts` constants. It is intentionally
|
||||
//! *configuration-only* — no live network I/O except for the optional
|
||||
//! `fetch_oauth_profile` helper at the bottom.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Scope constants (mirrors constants/oauth.ts)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The Claude.ai inference scope — required for Bearer-auth API calls.
|
||||
pub const CLAUDE_AI_INFERENCE_SCOPE: &str = "user:inference";
|
||||
|
||||
/// The profile scope — required to read account / subscription data.
|
||||
pub const CLAUDE_AI_PROFILE_SCOPE: &str = "user:profile";
|
||||
|
||||
/// Console scope — used when creating an API key via the Console flow.
|
||||
pub const CONSOLE_SCOPE: &str = "org:create_api_key";
|
||||
|
||||
/// All Claude.ai OAuth scopes (mirrors `CLAUDE_AI_OAUTH_SCOPES`).
|
||||
pub const CLAUDE_AI_OAUTH_SCOPES: &[&str] = &[
|
||||
CLAUDE_AI_PROFILE_SCOPE,
|
||||
CLAUDE_AI_INFERENCE_SCOPE,
|
||||
"user:sessions:claude_code",
|
||||
"user:mcp_servers",
|
||||
"user:file_upload",
|
||||
];
|
||||
|
||||
/// Console OAuth scopes (mirrors `CONSOLE_OAUTH_SCOPES`).
|
||||
pub const CONSOLE_OAUTH_SCOPES: &[&str] = &[CONSOLE_SCOPE, CLAUDE_AI_PROFILE_SCOPE];
|
||||
|
||||
/// Union of all scopes used during login (mirrors `ALL_OAUTH_SCOPES`).
|
||||
/// Requesting all at once lets a single login satisfy both Console and
|
||||
/// claude.ai auth paths.
|
||||
pub const ALL_OAUTH_SCOPES: &[&str] = &[
|
||||
CONSOLE_SCOPE,
|
||||
CLAUDE_AI_PROFILE_SCOPE,
|
||||
CLAUDE_AI_INFERENCE_SCOPE,
|
||||
"user:sessions:claude_code",
|
||||
"user:mcp_servers",
|
||||
"user:file_upload",
|
||||
];
|
||||
|
||||
/// Minimum scopes required for basic operation.
|
||||
pub const MINIMUM_SCOPES: &[&str] = &[CLAUDE_AI_INFERENCE_SCOPE, CLAUDE_AI_PROFILE_SCOPE];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OAuthConfig struct
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Full OAuth configuration for a deployment environment.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OAuthConfig {
|
||||
pub base_api_url: &'static str,
|
||||
pub console_authorize_url: &'static str,
|
||||
pub claude_ai_authorize_url: &'static str,
|
||||
/// The raw claude.ai web origin (separate from the authorize URL which
|
||||
/// may bounce through claude.com for attribution).
|
||||
pub claude_ai_origin: &'static str,
|
||||
pub token_url: &'static str,
|
||||
pub api_key_url: &'static str,
|
||||
pub roles_url: &'static str,
|
||||
pub console_success_url: &'static str,
|
||||
pub claudeai_success_url: &'static str,
|
||||
pub manual_redirect_url: &'static str,
|
||||
pub client_id: &'static str,
|
||||
pub oauth_file_suffix: &'static str,
|
||||
pub mcp_proxy_url: &'static str,
|
||||
pub mcp_proxy_path: &'static str,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Production config (mirrors PROD_OAUTH_CONFIG in oauth.ts)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub const PROD_OAUTH: OAuthConfig = OAuthConfig {
|
||||
base_api_url: "https://api.anthropic.com",
|
||||
// Routes through claude.com/cai/* for attribution, 307s to claude.ai in
|
||||
// two hops — same behaviour as the TypeScript client.
|
||||
console_authorize_url: "https://platform.claude.com/oauth/authorize",
|
||||
claude_ai_authorize_url: "https://claude.com/cai/oauth/authorize",
|
||||
claude_ai_origin: "https://claude.ai",
|
||||
token_url: "https://platform.claude.com/v1/oauth/token",
|
||||
api_key_url: "https://api.anthropic.com/api/oauth/claude_cli/create_api_key",
|
||||
roles_url: "https://api.anthropic.com/api/oauth/claude_cli/roles",
|
||||
console_success_url: "https://platform.claude.com/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code",
|
||||
claudeai_success_url: "https://platform.claude.com/oauth/code/success?app=claude-code",
|
||||
manual_redirect_url: "https://platform.claude.com/oauth/code/callback",
|
||||
client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e",
|
||||
oauth_file_suffix: "",
|
||||
mcp_proxy_url: "https://mcp-proxy.anthropic.com",
|
||||
mcp_proxy_path: "/v1/mcp/{server_id}",
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Staging config (mirrors STAGING_OAUTH_CONFIG — ant builds only)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub const STAGING_OAUTH: OAuthConfig = OAuthConfig {
|
||||
base_api_url: "https://api-staging.anthropic.com",
|
||||
console_authorize_url: "https://platform.staging.ant.dev/oauth/authorize",
|
||||
claude_ai_authorize_url: "https://claude-ai.staging.ant.dev/oauth/authorize",
|
||||
claude_ai_origin: "https://claude-ai.staging.ant.dev",
|
||||
token_url: "https://platform.staging.ant.dev/v1/oauth/token",
|
||||
api_key_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/create_api_key",
|
||||
roles_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/roles",
|
||||
console_success_url: "https://platform.staging.ant.dev/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code",
|
||||
claudeai_success_url: "https://platform.staging.ant.dev/oauth/code/success?app=claude-code",
|
||||
manual_redirect_url: "https://platform.staging.ant.dev/oauth/code/callback",
|
||||
client_id: "22422756-60c9-4084-8eb7-27705fd5cf9a",
|
||||
oauth_file_suffix: "-staging-oauth",
|
||||
mcp_proxy_url: "https://mcp-proxy-staging.anthropic.com",
|
||||
mcp_proxy_path: "/v1/mcp/{server_id}",
|
||||
};
|
||||
|
||||
/// Client-ID Metadata Document URL for MCP OAuth (CIMD / SEP-991).
|
||||
pub const MCP_CLIENT_METADATA_URL: &str =
|
||||
"https://claude.ai/oauth/claude-code-client-metadata";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config selection
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return the OAuth config appropriate for the current environment.
|
||||
///
|
||||
/// Selection logic mirrors `getOauthConfigType()` in `constants/oauth.ts`:
|
||||
/// - `USER_TYPE=ant` + `USE_STAGING_OAUTH=true` → staging
|
||||
/// - anything else → production
|
||||
///
|
||||
/// Note: the `local` variant from the TypeScript code is intentionally
|
||||
/// omitted here — local dev servers are not needed in the Rust port yet.
|
||||
pub fn get_oauth_config() -> &'static OAuthConfig {
|
||||
let user_type = std::env::var("USER_TYPE").unwrap_or_default();
|
||||
if user_type == "ant" {
|
||||
let use_staging = std::env::var("USE_STAGING_OAUTH")
|
||||
.map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
if use_staging {
|
||||
return &STAGING_OAUTH;
|
||||
}
|
||||
}
|
||||
&PROD_OAUTH
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PKCE helpers (mirrors src/services/oauth/crypto.ts)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// PKCE code-challenge / code-verifier helpers.
|
||||
pub mod pkce {
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Generate a cryptographically random code verifier (43–128 chars of
|
||||
/// Base64url characters, as required by RFC 7636).
|
||||
///
|
||||
/// Uses `getrandom` via the `rand` crate's OS RNG through the `uuid`
|
||||
/// crate's v4 generator — both already in-tree. Falls back to a
|
||||
/// time+pid mix if the OS RNG is unavailable.
|
||||
pub fn generate_code_verifier() -> String {
|
||||
// 32 random bytes → 43-char Base64url string (same as the TS impl).
|
||||
let bytes = random_bytes_32();
|
||||
URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
/// Compute `BASE64URL(SHA256(verifier))` — the S256 code challenge.
|
||||
pub fn code_challenge(verifier: &str) -> String {
|
||||
let hash = Sha256::digest(verifier.as_bytes());
|
||||
URL_SAFE_NO_PAD.encode(hash)
|
||||
}
|
||||
|
||||
/// Generate a random state parameter (16 Base64url chars).
|
||||
pub fn generate_state() -> String {
|
||||
let bytes = random_bytes_32();
|
||||
let encoded = URL_SAFE_NO_PAD.encode(bytes);
|
||||
// Take first 43 chars for a compact state parameter
|
||||
encoded.chars().take(43).collect()
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Internal: produce 32 random bytes.
|
||||
// We derive them from a UUID v4 (which already pulls from the OS RNG
|
||||
// via the `uuid` crate) so we don't need to add a new `rand` dep.
|
||||
// ------------------------------------------------------------------
|
||||
fn random_bytes_32() -> [u8; 32] {
|
||||
// Two UUID v4 values give us 32 bytes of OS-backed randomness.
|
||||
let u1 = uuid::Uuid::new_v4();
|
||||
let u2 = uuid::Uuid::new_v4();
|
||||
let mut out = [0u8; 32];
|
||||
out[..16].copy_from_slice(u1.as_bytes());
|
||||
out[16..].copy_from_slice(u2.as_bytes());
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Token and profile types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Raw OAuth token response from the token endpoint.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires_in: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub refresh_token: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scope: Option<String>,
|
||||
}
|
||||
|
||||
/// Slim profile fetched after token exchange.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct OAuthProfile {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub email: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub account_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub subscription_tier: Option<String>,
|
||||
}
|
||||
|
||||
/// Fetch the OAuth profile using an access token.
|
||||
///
|
||||
/// Returns a default (all-`None`) profile on any non-success response so
|
||||
/// callers can treat a profile fetch failure as non-fatal.
|
||||
pub async fn fetch_oauth_profile(
|
||||
access_token: &str,
|
||||
api_base: &str,
|
||||
) -> anyhow::Result<OAuthProfile> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{}/api/auth/oauth/profile", api_base.trim_end_matches('/'));
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(access_token)
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
let profile: OAuthProfile = resp.json().await.unwrap_or_default();
|
||||
Ok(profile)
|
||||
} else {
|
||||
// Non-fatal: return an empty profile so the caller can continue.
|
||||
Ok(OAuthProfile::default())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Auth URL builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build the OAuth authorization URL (mirrors `buildAuthUrl` in client.ts).
|
||||
pub fn build_auth_url(
|
||||
code_challenge: &str,
|
||||
state: &str,
|
||||
port: u16,
|
||||
is_manual: bool,
|
||||
login_with_claude_ai: bool,
|
||||
inference_only: bool,
|
||||
) -> String {
|
||||
let cfg = get_oauth_config();
|
||||
|
||||
let base = if login_with_claude_ai {
|
||||
cfg.claude_ai_authorize_url
|
||||
} else {
|
||||
cfg.console_authorize_url
|
||||
};
|
||||
|
||||
let redirect_uri = if is_manual {
|
||||
cfg.manual_redirect_url.to_string()
|
||||
} else {
|
||||
format!("http://localhost:{}/callback", port)
|
||||
};
|
||||
|
||||
let scopes: Vec<&str> = if inference_only {
|
||||
vec![CLAUDE_AI_INFERENCE_SCOPE]
|
||||
} else {
|
||||
ALL_OAUTH_SCOPES.to_vec()
|
||||
};
|
||||
|
||||
let scope_str = scopes.join(" ");
|
||||
|
||||
format!(
|
||||
"{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
||||
base,
|
||||
urlencoding::encode(cfg.client_id),
|
||||
urlencoding::encode(&redirect_uri),
|
||||
urlencoding::encode(&scope_str),
|
||||
urlencoding::encode(code_challenge),
|
||||
urlencoding::encode(state),
|
||||
)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prod_config_urls_are_https() {
|
||||
assert!(PROD_OAUTH.token_url.starts_with("https://"));
|
||||
assert!(PROD_OAUTH.api_key_url.starts_with("https://"));
|
||||
assert!(PROD_OAUTH.claude_ai_authorize_url.starts_with("https://"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_staging_config_urls_are_https() {
|
||||
assert!(STAGING_OAUTH.token_url.starts_with("https://"));
|
||||
assert!(STAGING_OAUTH.api_key_url.starts_with("https://"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pkce_code_challenge_is_base64url() {
|
||||
let verifier = pkce::generate_code_verifier();
|
||||
assert!(!verifier.is_empty());
|
||||
// Base64url characters only (no +, /, =)
|
||||
assert!(!verifier.contains('+'));
|
||||
assert!(!verifier.contains('/'));
|
||||
assert!(!verifier.contains('='));
|
||||
|
||||
let challenge = pkce::code_challenge(&verifier);
|
||||
assert!(!challenge.is_empty());
|
||||
assert!(!challenge.contains('+'));
|
||||
assert!(!challenge.contains('/'));
|
||||
assert!(!challenge.contains('='));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verifier_length_meets_rfc7636_minimum() {
|
||||
let verifier = pkce::generate_code_verifier();
|
||||
// RFC 7636 §4.1: code_verifier length ∈ [43, 128]
|
||||
assert!(
|
||||
verifier.len() >= 43,
|
||||
"verifier too short: {} chars",
|
||||
verifier.len()
|
||||
);
|
||||
assert!(verifier.len() <= 128, "verifier too long: {} chars", verifier.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_oauth_scopes_contains_inference() {
|
||||
assert!(ALL_OAUTH_SCOPES.contains(&CLAUDE_AI_INFERENCE_SCOPE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_auth_url_contains_required_params() {
|
||||
let url = build_auth_url("challenge123", "state456", 8080, false, true, false);
|
||||
assert!(url.contains("challenge123"));
|
||||
assert!(url.contains("state456"));
|
||||
assert!(url.contains("S256"));
|
||||
assert!(url.contains("localhost"));
|
||||
}
|
||||
}
|
||||
347
src-rust/crates/core/src/output_styles.rs
Normal file
347
src-rust/crates/core/src/output_styles.rs
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
//! Output style system — customises how Claude responds to the user.
|
||||
//!
|
||||
//! Styles are applied by injecting `OutputStyleDef::prompt` into the system
|
||||
//! prompt. Built-in styles are defined in code; users can add their own by
|
||||
//! placing `.md` or `.json` files in:
|
||||
//! - Global: `~/.claude/output-styles/`
|
||||
//! - Project: `.claude/output-styles/`
|
||||
//!
|
||||
//! Markdown style files have a simple structure:
|
||||
//! Line 1: `# <Label>` (heading becomes the label)
|
||||
//! Line 2: short description
|
||||
//! Remainder: the prompt text injected into the system prompt
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A single output style definition.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct OutputStyleDef {
|
||||
/// Machine-readable identifier (e.g. `"concise"`).
|
||||
pub name: String,
|
||||
/// Human-readable label shown in picker UI (e.g. `"Concise"`).
|
||||
pub label: String,
|
||||
/// One-line description.
|
||||
pub description: String,
|
||||
/// Text injected into the system prompt when this style is active.
|
||||
/// Empty string for the default style (no extra injection).
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
impl OutputStyleDef {
|
||||
// ---- Built-in styles ---------------------------------------------------
|
||||
|
||||
pub fn builtin_default() -> Self {
|
||||
Self {
|
||||
name: "default".to_string(),
|
||||
label: "Default".to_string(),
|
||||
description: "Standard Claude Code responses.".to_string(),
|
||||
prompt: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builtin_concise() -> Self {
|
||||
Self {
|
||||
name: "concise".to_string(),
|
||||
label: "Concise".to_string(),
|
||||
description: "Short, direct responses with minimal explanation.".to_string(),
|
||||
prompt: "Be maximally concise. Skip preamble, summaries, and filler. \
|
||||
Lead with the answer."
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builtin_explanatory() -> Self {
|
||||
Self {
|
||||
name: "explanatory".to_string(),
|
||||
label: "Explanatory".to_string(),
|
||||
description: "Thorough explanations with reasoning and alternatives.".to_string(),
|
||||
prompt: "When explaining code or concepts, be thorough and educational. \
|
||||
Include reasoning, alternatives considered, and potential pitfalls. \
|
||||
Err on the side of over-explaining."
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builtin_learning() -> Self {
|
||||
Self {
|
||||
name: "learning".to_string(),
|
||||
label: "Learning".to_string(),
|
||||
description: "Pedagogical mode — explains patterns and decisions.".to_string(),
|
||||
prompt: "This user is learning. Explain concepts as you implement them. \
|
||||
Point out patterns, best practices, and why you made each decision. \
|
||||
Use analogies when helpful."
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Built-ins
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return all built-in output styles in display order.
|
||||
pub fn builtin_styles() -> Vec<OutputStyleDef> {
|
||||
vec![
|
||||
OutputStyleDef::builtin_default(),
|
||||
OutputStyleDef::builtin_concise(),
|
||||
OutputStyleDef::builtin_explanatory(),
|
||||
OutputStyleDef::builtin_learning(),
|
||||
]
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Loading from disk
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Load user-defined output styles from a directory.
|
||||
///
|
||||
/// Supported file formats:
|
||||
/// - `.md` — Markdown: `# Label\ndescription\n\nprompt text…`
|
||||
/// - `.json` — JSON: `{ "name": "…", "label": "…", "description": "…", "prompt": "…" }`
|
||||
///
|
||||
/// Files that cannot be parsed are silently skipped.
|
||||
pub fn load_output_styles_dir(styles_dir: &Path) -> Vec<OutputStyleDef> {
|
||||
if !styles_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let entries = match std::fs::read_dir(styles_dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let mut styles = Vec::new();
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
|
||||
if ext == "md" || ext == "json" {
|
||||
if let Some(style) = load_style_file(&path) {
|
||||
styles.push(style);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort alphabetically so the list is deterministic.
|
||||
styles.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
styles
|
||||
}
|
||||
|
||||
fn load_style_file(path: &Path) -> Option<OutputStyleDef> {
|
||||
let content = std::fs::read_to_string(path).ok()?;
|
||||
let stem = path.file_stem()?.to_string_lossy().into_owned();
|
||||
|
||||
if path.extension().and_then(|e| e.to_str()) == Some("json") {
|
||||
// Try deserialising directly; fall back to inserting the stem as name.
|
||||
let mut def: OutputStyleDef = serde_json::from_str(&content).ok()?;
|
||||
if def.name.is_empty() {
|
||||
def.name = stem;
|
||||
}
|
||||
return Some(def);
|
||||
}
|
||||
|
||||
// Markdown format:
|
||||
// Line 1: # Label (optional leading `#` and whitespace)
|
||||
// Line 2: description (short, plain text)
|
||||
// Lines 3+: prompt text (everything after the blank / second line)
|
||||
let mut lines = content.lines();
|
||||
|
||||
let raw_label = lines.next().unwrap_or("").trim().to_string();
|
||||
let label = raw_label.trim_start_matches('#').trim().to_string();
|
||||
let label = if label.is_empty() { stem.clone() } else { label };
|
||||
|
||||
let description = lines
|
||||
.next()
|
||||
.map(|l| l.trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Collect remaining lines as the prompt, trimming leading blank lines.
|
||||
let prompt_lines: Vec<&str> = lines.collect();
|
||||
let prompt = prompt_lines.join("\n").trim().to_string();
|
||||
|
||||
Some(OutputStyleDef {
|
||||
name: stem,
|
||||
label,
|
||||
description,
|
||||
prompt,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Aggregated access
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return all styles available for `config_dir`:
|
||||
/// built-ins first, then styles from `<config_dir>/output-styles/`.
|
||||
///
|
||||
/// `config_dir` is typically `~/.claude`.
|
||||
pub fn all_styles(config_dir: &Path) -> Vec<OutputStyleDef> {
|
||||
let mut styles = builtin_styles();
|
||||
let user_dir = config_dir.join("output-styles");
|
||||
styles.extend(load_output_styles_dir(&user_dir));
|
||||
styles
|
||||
}
|
||||
|
||||
/// Find a style by its `name` field.
|
||||
pub fn find_style<'a>(styles: &'a [OutputStyleDef], name: &str) -> Option<&'a OutputStyleDef> {
|
||||
styles.iter().find(|s| s.name == name)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write as IoWrite;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// ---- builtin_styles ----------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn builtin_styles_non_empty() {
|
||||
assert!(!builtin_styles().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builtin_styles_have_unique_names() {
|
||||
let styles = builtin_styles();
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for s in &styles {
|
||||
assert!(seen.insert(&s.name), "duplicate style name: {}", s.name);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builtin_default_has_empty_prompt() {
|
||||
let def = OutputStyleDef::builtin_default();
|
||||
assert!(def.prompt.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builtin_non_default_have_prompts() {
|
||||
for s in builtin_styles() {
|
||||
if s.name != "default" {
|
||||
assert!(
|
||||
!s.prompt.is_empty(),
|
||||
"style '{}' should have a non-empty prompt",
|
||||
s.name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- find_style --------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn find_style_by_name() {
|
||||
let styles = builtin_styles();
|
||||
let found = find_style(&styles, "concise");
|
||||
assert!(found.is_some());
|
||||
assert_eq!(found.unwrap().name, "concise");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_style_missing() {
|
||||
let styles = builtin_styles();
|
||||
assert!(find_style(&styles, "nonexistent-xyz").is_none());
|
||||
}
|
||||
|
||||
// ---- load_output_styles_dir (markdown) ---------------------------------
|
||||
|
||||
fn write_file(dir: &TempDir, name: &str, content: &str) {
|
||||
let path = dir.path().join(name);
|
||||
let mut f = std::fs::File::create(path).unwrap();
|
||||
f.write_all(content.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_markdown_style() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(
|
||||
&dir,
|
||||
"terse.md",
|
||||
"# Terse\nVery short answers.\n\nOne sentence per response.",
|
||||
);
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert_eq!(styles.len(), 1);
|
||||
let s = &styles[0];
|
||||
assert_eq!(s.name, "terse");
|
||||
assert_eq!(s.label, "Terse");
|
||||
assert_eq!(s.description, "Very short answers.");
|
||||
assert_eq!(s.prompt, "One sentence per response.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_json_style() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(
|
||||
&dir,
|
||||
"formal.json",
|
||||
r#"{"name":"formal","label":"Formal","description":"Formal tone.","prompt":"Use formal language."}"#,
|
||||
);
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert_eq!(styles.len(), 1);
|
||||
let s = &styles[0];
|
||||
assert_eq!(s.name, "formal");
|
||||
assert_eq!(s.label, "Formal");
|
||||
assert_eq!(s.prompt, "Use formal language.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_skips_unknown_extensions() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(&dir, "ignore.txt", "should be skipped");
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert!(styles.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_non_existent_dir_returns_empty() {
|
||||
use std::path::PathBuf;
|
||||
let styles = load_output_styles_dir(&PathBuf::from("/nonexistent/path/xyz"));
|
||||
assert!(styles.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_multiple_styles_sorted() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
write_file(&dir, "zebra.md", "# Zebra\nZ style.\n\nZ prompt.");
|
||||
write_file(&dir, "apple.md", "# Apple\nA style.\n\nA prompt.");
|
||||
let styles = load_output_styles_dir(dir.path());
|
||||
assert_eq!(styles[0].name, "apple");
|
||||
assert_eq!(styles[1].name, "zebra");
|
||||
}
|
||||
|
||||
// ---- all_styles --------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn all_styles_includes_builtins() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
// no output-styles subdir → only built-ins
|
||||
let styles = all_styles(dir.path());
|
||||
assert!(styles.iter().any(|s| s.name == "default"));
|
||||
assert!(styles.iter().any(|s| s.name == "concise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_styles_merges_user_styles() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let output_styles_dir = dir.path().join("output-styles");
|
||||
std::fs::create_dir_all(&output_styles_dir).unwrap();
|
||||
|
||||
// Write a user style file.
|
||||
let mut f = std::fs::File::create(output_styles_dir.join("pirate.md")).unwrap();
|
||||
f.write_all(b"# Pirate\nSpeak like a pirate.\n\nArrr matey!").unwrap();
|
||||
|
||||
let styles = all_styles(dir.path());
|
||||
assert!(styles.iter().any(|s| s.name == "pirate"));
|
||||
// Built-ins still present.
|
||||
assert!(styles.iter().any(|s| s.name == "default"));
|
||||
}
|
||||
}
|
||||
526
src-rust/crates/core/src/system_prompt.rs
Normal file
526
src-rust/crates/core/src/system_prompt.rs
Normal file
|
|
@ -0,0 +1,526 @@
|
|||
//! Modular system prompt assembly with caching support.
|
||||
//!
|
||||
//! Mirrors the TypeScript `systemPromptSections.ts` / `prompts.ts` architecture:
|
||||
//! cacheable (static) sections are placed before `SYSTEM_PROMPT_DYNAMIC_BOUNDARY`;
|
||||
//! volatile, session-specific sections follow it.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Dynamic boundary marker
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Marker that splits the cached vs dynamic parts of the system prompt.
|
||||
/// Everything before this marker can be prompt-cached by the API.
|
||||
/// Matches the TypeScript constant `SYSTEM_PROMPT_DYNAMIC_BOUNDARY`.
|
||||
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Section cache (mirrors bootstrap/state.ts systemPromptSectionCache)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn section_cache() -> &'static Mutex<HashMap<String, Option<String>>> {
|
||||
static CACHE: OnceLock<Mutex<HashMap<String, Option<String>>>> = OnceLock::new();
|
||||
CACHE.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Clear all cached system prompt sections (called on /clear and /compact).
|
||||
pub fn clear_system_prompt_sections() {
|
||||
if let Ok(mut cache) = section_cache().lock() {
|
||||
cache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// A single named section of the system prompt.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SystemPromptSection {
|
||||
/// Identifier used for cache lookups and invalidation.
|
||||
pub tag: &'static str,
|
||||
/// Computed content (None means the section is absent/disabled).
|
||||
pub content: Option<String>,
|
||||
/// If true the section is volatile and must not be prompt-cached.
|
||||
pub cache_break: bool,
|
||||
}
|
||||
|
||||
impl SystemPromptSection {
|
||||
/// Create a memoizable (cacheable) section.
|
||||
pub fn cached(tag: &'static str, content: impl Into<String>) -> Self {
|
||||
Self { tag, content: Some(content.into()), cache_break: false }
|
||||
}
|
||||
|
||||
/// Create a volatile section that re-evaluates every turn.
|
||||
/// Passing `None` for content means the section is absent this turn.
|
||||
pub fn uncached(tag: &'static str, content: Option<impl Into<String>>) -> Self {
|
||||
Self {
|
||||
tag,
|
||||
content: content.map(|c| c.into()),
|
||||
cache_break: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Output style
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Output styles that affect the system prompt.
|
||||
/// Serialised as lowercase strings to match settings.json.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OutputStyle {
|
||||
#[default]
|
||||
Default,
|
||||
Explanatory,
|
||||
Learning,
|
||||
Concise,
|
||||
Formal,
|
||||
Casual,
|
||||
}
|
||||
|
||||
impl OutputStyle {
|
||||
/// Returns the system-prompt suffix for this style, or `None` for Default.
|
||||
pub fn prompt_suffix(self) -> Option<&'static str> {
|
||||
match self {
|
||||
OutputStyle::Explanatory => Some(
|
||||
"When explaining code or concepts, be thorough and educational. \
|
||||
Include reasoning, alternatives considered, and potential pitfalls. \
|
||||
Err on the side of over-explaining.",
|
||||
),
|
||||
OutputStyle::Learning => Some(
|
||||
"This user is learning. Explain concepts as you implement them. \
|
||||
Point out patterns, best practices, and why you made each decision. \
|
||||
Use analogies when helpful.",
|
||||
),
|
||||
OutputStyle::Concise => Some(
|
||||
"Be maximally concise. Skip preamble, summaries, and filler. \
|
||||
Lead with the answer. One sentence is better than three.",
|
||||
),
|
||||
OutputStyle::Formal => Some(
|
||||
"Maintain a formal, professional tone. Use precise technical language.",
|
||||
),
|
||||
OutputStyle::Casual => Some("Use a casual, conversational tone."),
|
||||
OutputStyle::Default => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from a string (case-insensitive).
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"explanatory" => Self::Explanatory,
|
||||
"learning" => Self::Learning,
|
||||
"concise" => Self::Concise,
|
||||
"formal" => Self::Formal,
|
||||
"casual" => Self::Casual,
|
||||
_ => Self::Default,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// System prompt prefix variants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Which entrypoint context Claude Code is running in.
|
||||
/// Determines the opening attribution line of the system prompt.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum SystemPromptPrefix {
|
||||
/// Standard interactive CLI session.
|
||||
Cli,
|
||||
/// Running as a sub-agent spawned by the Claude Agent SDK.
|
||||
Sdk,
|
||||
/// The CLI preset running within the Agent SDK
|
||||
/// (non-interactive + append_system_prompt set).
|
||||
SdkPreset,
|
||||
/// Running on Vertex AI.
|
||||
Vertex,
|
||||
/// Running on AWS Bedrock.
|
||||
Bedrock,
|
||||
/// Remote / headless CCR session.
|
||||
Remote,
|
||||
}
|
||||
|
||||
impl SystemPromptPrefix {
|
||||
/// Detect from environment variables, mirroring `getCLISyspromptPrefix`.
|
||||
pub fn detect(is_non_interactive: bool, has_append_system_prompt: bool) -> Self {
|
||||
// Vertex: always uses the default "Claude Code" prefix.
|
||||
if std::env::var("ANTHROPIC_VERTEX_PROJECT_ID").is_ok()
|
||||
|| std::env::var("CLOUD_ML_PROJECT_ID").is_ok()
|
||||
{
|
||||
return Self::Vertex;
|
||||
}
|
||||
|
||||
if std::env::var("AWS_BEDROCK_MODEL_ID").is_ok() {
|
||||
return Self::Bedrock;
|
||||
}
|
||||
|
||||
if std::env::var("CLAUDE_CODE_REMOTE").is_ok() {
|
||||
return Self::Remote;
|
||||
}
|
||||
|
||||
// Non-interactive mode maps to SDK variants (matches TS getCLISyspromptPrefix).
|
||||
if is_non_interactive {
|
||||
if has_append_system_prompt {
|
||||
return Self::SdkPreset;
|
||||
}
|
||||
return Self::Sdk;
|
||||
}
|
||||
|
||||
Self::Cli
|
||||
}
|
||||
|
||||
/// The opening attribution string for this prefix variant.
|
||||
pub fn attribution_text(self) -> &'static str {
|
||||
match self {
|
||||
Self::Cli | Self::Vertex | Self::Bedrock | Self::Remote => {
|
||||
"You are Claude Code, Anthropic's official CLI for Claude."
|
||||
}
|
||||
Self::SdkPreset => {
|
||||
"You are Claude Code, Anthropic's official CLI for Claude, \
|
||||
running within the Claude Agent SDK."
|
||||
}
|
||||
Self::Sdk => {
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Build options
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// All options controlling what goes into the assembled system prompt.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SystemPromptOptions {
|
||||
/// Override auto-detected prefix.
|
||||
pub prefix: Option<SystemPromptPrefix>,
|
||||
/// Whether the session is non-interactive (SDK / pipe mode).
|
||||
pub is_non_interactive: bool,
|
||||
/// Whether --append-system-prompt is set (affects prefix detection).
|
||||
pub has_append_system_prompt: bool,
|
||||
/// Output style to inject.
|
||||
pub output_style: OutputStyle,
|
||||
/// Absolute path to the working directory (injected as dynamic section).
|
||||
pub working_directory: Option<String>,
|
||||
/// Pre-built memory content from memdir (injected as dynamic section).
|
||||
pub memory_content: String,
|
||||
/// Custom system prompt (--system-prompt flag or settings).
|
||||
pub custom_system_prompt: Option<String>,
|
||||
/// Additional text appended after everything else (--append-system-prompt).
|
||||
pub append_system_prompt: Option<String>,
|
||||
/// If true and `custom_system_prompt` is set, the entire default prompt is
|
||||
/// replaced — only the custom text + dynamic boundary are emitted.
|
||||
pub replace_system_prompt: bool,
|
||||
/// Inject the coordinator-mode section.
|
||||
pub coordinator_mode: bool,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main assembly function
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build the complete system prompt string.
|
||||
///
|
||||
/// The returned string contains `SYSTEM_PROMPT_DYNAMIC_BOUNDARY` as an
|
||||
/// internal marker. Callers (e.g. `buildSystemPromptBlocks` in cc-query)
|
||||
/// split on this marker to determine which portions are eligible for
|
||||
/// Anthropic prompt-caching.
|
||||
pub fn build_system_prompt(opts: &SystemPromptOptions) -> String {
|
||||
// Replace mode: skip all default sections.
|
||||
if opts.replace_system_prompt {
|
||||
if let Some(custom) = &opts.custom_system_prompt {
|
||||
return format!("{}\n\n{}", custom, SYSTEM_PROMPT_DYNAMIC_BOUNDARY);
|
||||
}
|
||||
}
|
||||
|
||||
let prefix = opts
|
||||
.prefix
|
||||
.unwrap_or_else(|| {
|
||||
SystemPromptPrefix::detect(
|
||||
opts.is_non_interactive,
|
||||
opts.has_append_system_prompt,
|
||||
)
|
||||
});
|
||||
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
|
||||
// ------------------------------------------------------------------ //
|
||||
// CACHEABLE sections (before the dynamic boundary) //
|
||||
// ------------------------------------------------------------------ //
|
||||
|
||||
// 1. Attribution header
|
||||
parts.push(prefix.attribution_text().to_string());
|
||||
|
||||
// 2. Core capabilities
|
||||
parts.push(CORE_CAPABILITIES.to_string());
|
||||
|
||||
// 3. Tool use guidelines
|
||||
parts.push(TOOL_USE_GUIDELINES.to_string());
|
||||
|
||||
// 4. Executing actions with care
|
||||
parts.push(ACTIONS_SECTION.to_string());
|
||||
|
||||
// 5. Safety guidelines
|
||||
parts.push(SAFETY_GUIDELINES.to_string());
|
||||
|
||||
// 6. Cyber-risk instruction (owned by safeguards — do not edit)
|
||||
parts.push(CYBER_RISK_INSTRUCTION.to_string());
|
||||
|
||||
// 7. Output style (cacheable when non-Default; its content is stable)
|
||||
if let Some(style_text) = opts.output_style.prompt_suffix() {
|
||||
parts.push(format!("\n## Output Style\n{}", style_text));
|
||||
}
|
||||
|
||||
// 8. Coordinator mode (cacheable: content is constant)
|
||||
if opts.coordinator_mode {
|
||||
parts.push(COORDINATOR_SYSTEM_PROMPT.to_string());
|
||||
}
|
||||
|
||||
// 9. Custom system prompt addition (appended to cacheable block)
|
||||
if let Some(custom) = &opts.custom_system_prompt {
|
||||
parts.push(format!(
|
||||
"\n<custom_instructions>\n{}\n</custom_instructions>",
|
||||
custom
|
||||
));
|
||||
}
|
||||
|
||||
// Dynamic boundary marker
|
||||
parts.push(SYSTEM_PROMPT_DYNAMIC_BOUNDARY.to_string());
|
||||
|
||||
// ------------------------------------------------------------------ //
|
||||
// DYNAMIC / UNCACHEABLE sections (after the boundary) //
|
||||
// ------------------------------------------------------------------ //
|
||||
|
||||
// 10. Working directory
|
||||
if let Some(cwd) = &opts.working_directory {
|
||||
parts.push(format!("\n<working_directory>{}</working_directory>", cwd));
|
||||
}
|
||||
|
||||
// 11. Memory injection (from memdir)
|
||||
if !opts.memory_content.is_empty() {
|
||||
parts.push(format!(
|
||||
"\n<memory>\n{}\n</memory>",
|
||||
opts.memory_content
|
||||
));
|
||||
}
|
||||
|
||||
// 12. Appended system prompt (--append-system-prompt)
|
||||
if let Some(append) = &opts.append_system_prompt {
|
||||
parts.push(format!("\n{}", append));
|
||||
}
|
||||
|
||||
parts.join("\n")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Static system prompt sections
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const CORE_CAPABILITIES: &str = r#"
|
||||
## Capabilities
|
||||
|
||||
You have access to powerful tools for software engineering tasks:
|
||||
- **Read/Write files**: Read any file, write new files, edit existing files with precise diffs
|
||||
- **Execute commands**: Run bash commands, PowerShell scripts, background processes
|
||||
- **Search**: Glob patterns, regex grep, web search, file content search
|
||||
- **Web**: Fetch URLs, search the internet
|
||||
- **Agents**: Spawn parallel sub-agents for complex multi-step work
|
||||
- **Memory**: Persistent notes across sessions via the memory system
|
||||
- **MCP servers**: Connect to external tools and APIs via Model Context Protocol
|
||||
- **Jupyter notebooks**: Read and edit notebook cells
|
||||
|
||||
## How to approach tasks
|
||||
|
||||
1. **Understand before acting**: Read relevant files before making changes
|
||||
2. **Minimal changes**: Only modify what's needed. Don't refactor unrequested code.
|
||||
3. **Verify**: Check your work with tests or by reading the result
|
||||
4. **Communicate blockers**: If stuck, ask the user rather than guessing
|
||||
"#;
|
||||
|
||||
const TOOL_USE_GUIDELINES: &str = r#"
|
||||
## Tool use guidelines
|
||||
|
||||
- Use dedicated tools (Read, Edit, Glob, Grep) instead of bash equivalents
|
||||
- For searches, prefer Grep over `grep`; prefer Glob over `find`
|
||||
- Parallelize independent tool calls in a single response
|
||||
- For file edits: always read the file first, then make targeted edits
|
||||
- Bash commands timeout after 2 minutes; use background mode for long operations
|
||||
"#;
|
||||
|
||||
const ACTIONS_SECTION: &str = r#"
|
||||
## Executing actions with care
|
||||
|
||||
Carefully consider the reversibility and blast radius of actions. For actions
|
||||
that are hard to reverse, affect shared systems, or could be risky or
|
||||
destructive, check with the user before proceeding. Authorization stands for
|
||||
the scope specified, not beyond. Match the scope of your actions to what was
|
||||
actually requested.
|
||||
"#;
|
||||
|
||||
const SAFETY_GUIDELINES: &str = r#"
|
||||
## Safety guidelines
|
||||
|
||||
- Never delete files without explicit user confirmation
|
||||
- Don't modify protected files (.gitconfig, .bashrc, .zshrc, .mcp.json, .claude.json)
|
||||
- Be careful with destructive operations (rm -rf, DROP TABLE, etc.)
|
||||
- Don't commit secrets, credentials, or API keys
|
||||
- For ambiguous destructive actions, ask before proceeding
|
||||
"#;
|
||||
|
||||
/// Security / cyber-risk instruction.
|
||||
/// Owned by the safeguards team — do not modify the content here.
|
||||
const CYBER_RISK_INSTRUCTION: &str = r#"
|
||||
## Security
|
||||
|
||||
You are authorized to assist with security research, CTF challenges, penetration testing
|
||||
with explicit authorization, defensive security, and educational security content. Do not
|
||||
assist with creating malware, unauthorized access, denial-of-service attacks, or any
|
||||
destructive security techniques without clear legitimate purpose.
|
||||
"#;
|
||||
|
||||
const COORDINATOR_SYSTEM_PROMPT: &str = r#"
|
||||
## Coordinator Mode
|
||||
|
||||
You are operating as an orchestrator. Spawn parallel worker agents using the Agent tool.
|
||||
Each worker prompt must be fully self-contained. Synthesize findings before delegating
|
||||
follow-up work. Use TaskCreate/TaskUpdate to track parallel work.
|
||||
"#;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn default_opts() -> SystemPromptOptions {
|
||||
SystemPromptOptions::default()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_prompt_contains_boundary() {
|
||||
let prompt = build_system_prompt(&default_opts());
|
||||
assert!(
|
||||
prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY),
|
||||
"System prompt must contain the dynamic boundary marker"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_prompt_contains_attribution() {
|
||||
let prompt = build_system_prompt(&default_opts());
|
||||
assert!(prompt.contains("Claude Code"), "Default prompt must contain attribution");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replace_system_prompt() {
|
||||
let opts = SystemPromptOptions {
|
||||
custom_system_prompt: Some("Custom only.".to_string()),
|
||||
replace_system_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
assert!(prompt.starts_with("Custom only."));
|
||||
assert!(!prompt.contains("Capabilities"));
|
||||
assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_working_directory_in_dynamic_section() {
|
||||
let opts = SystemPromptOptions {
|
||||
working_directory: Some("/home/user/project".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
let boundary_pos = prompt.find(SYSTEM_PROMPT_DYNAMIC_BOUNDARY).unwrap();
|
||||
let cwd_pos = prompt.find("/home/user/project").unwrap();
|
||||
assert!(
|
||||
cwd_pos > boundary_pos,
|
||||
"Working directory must appear after the dynamic boundary"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_content_in_dynamic_section() {
|
||||
let opts = SystemPromptOptions {
|
||||
memory_content: "- [test.md](test.md) — a test memory".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
let boundary_pos = prompt.find(SYSTEM_PROMPT_DYNAMIC_BOUNDARY).unwrap();
|
||||
let mem_pos = prompt.find("test.md").unwrap();
|
||||
assert!(
|
||||
mem_pos > boundary_pos,
|
||||
"Memory content must appear after the dynamic boundary"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_style_concise() {
|
||||
let opts = SystemPromptOptions {
|
||||
output_style: OutputStyle::Concise,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
assert!(prompt.contains("maximally concise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_style_default_has_no_suffix() {
|
||||
let opts = SystemPromptOptions {
|
||||
output_style: OutputStyle::Default,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
// None of the style suffixes should appear
|
||||
assert!(!prompt.contains("maximally concise"));
|
||||
assert!(!prompt.contains("This user is learning"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coordinator_mode_section() {
|
||||
let opts = SystemPromptOptions {
|
||||
coordinator_mode: true,
|
||||
..Default::default()
|
||||
};
|
||||
let prompt = build_system_prompt(&opts);
|
||||
assert!(prompt.contains("Coordinator Mode"));
|
||||
assert!(prompt.contains("orchestrator"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_style_from_str() {
|
||||
assert_eq!(OutputStyle::from_str("concise"), OutputStyle::Concise);
|
||||
assert_eq!(OutputStyle::from_str("FORMAL"), OutputStyle::Formal);
|
||||
assert_eq!(OutputStyle::from_str("unknown"), OutputStyle::Default);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sdk_prefix_non_interactive_no_append() {
|
||||
let prefix = SystemPromptPrefix::detect(true, false);
|
||||
assert_eq!(prefix, SystemPromptPrefix::Sdk);
|
||||
assert!(prefix.attribution_text().contains("Claude agent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sdk_preset_prefix_non_interactive_with_append() {
|
||||
let prefix = SystemPromptPrefix::detect(true, true);
|
||||
assert_eq!(prefix, SystemPromptPrefix::SdkPreset);
|
||||
assert!(prefix.attribution_text().contains("Claude Agent SDK"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear_section_cache() {
|
||||
// Populate cache then clear it — should not panic.
|
||||
{
|
||||
let mut cache = section_cache().lock().unwrap();
|
||||
cache.insert("test_section".to_string(), Some("content".to_string()));
|
||||
}
|
||||
clear_system_prompt_sections();
|
||||
let cache = section_cache().lock().unwrap();
|
||||
assert!(cache.is_empty());
|
||||
}
|
||||
}
|
||||
682
src-rust/crates/core/src/team_memory_sync.rs
Normal file
682
src-rust/crates/core/src/team_memory_sync.rs
Normal file
|
|
@ -0,0 +1,682 @@
|
|||
//! Team memory synchronization with claude.ai API.
|
||||
//!
|
||||
//! Implements delta push (only changed files) with ETag-based optimistic
|
||||
//! concurrency and greedy bin-packing of changed entries into batches that
|
||||
//! fit within the server's PUT body limit.
|
||||
//!
|
||||
//! Pull is server-wins: remote content overwrites local files unconditionally.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Maximum bytes per local file accepted for sync (250 KB)
|
||||
const MAX_FILE_SIZE_BYTES: usize = 250 * 1024;
|
||||
|
||||
/// Maximum serialized bytes per PUT request body (200 KB)
|
||||
const MAX_PUT_BODY_BYTES: usize = 200 * 1024;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Data types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Persisted per-repo sync state (stored alongside local team-memory files).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct SyncState {
|
||||
/// ETag returned by the last successful GET or PUT.
|
||||
pub last_known_etag: Option<String>,
|
||||
/// Per-key server-side checksums (`"sha256:<hex>"`).
|
||||
/// Used to diff local vs remote without re-uploading unchanged entries.
|
||||
pub server_checksums: HashMap<String, String>,
|
||||
/// Server-enforced max_entries from a prior 413 response.
|
||||
pub server_max_entries: Option<usize>,
|
||||
}
|
||||
|
||||
/// A single team-memory entry (one markdown file).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TeamMemoryEntry {
|
||||
/// Relative file path (forward-slash separated, e.g. `"MEMORY.md"`).
|
||||
pub key: String,
|
||||
/// UTF-8 file content (typically Markdown).
|
||||
pub content: String,
|
||||
/// `"sha256:<hex>"` of the content.
|
||||
pub checksum: String,
|
||||
}
|
||||
|
||||
/// Server response shape for GET `/api/claude_code/team_memory`.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TeamMemoryData {
|
||||
pub entries: Vec<TeamMemoryEntry>,
|
||||
pub etag: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Checksum helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute `"sha256:<lowercase hex>"` of a string.
|
||||
pub fn content_checksum(content: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(content.as_bytes());
|
||||
format!("sha256:{}", hex::encode(hasher.finalize()))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Path security validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Reject paths that could escape the team-memory directory.
|
||||
///
|
||||
/// Checks performed (mirroring the TypeScript `securePath` validation):
|
||||
/// - No null bytes
|
||||
/// - No URL-encoded traversal sequences (`%2e`, `%2f`, case-insensitive)
|
||||
/// - No backslashes
|
||||
/// - Not an absolute path (Unix `/` or Windows `C:` style)
|
||||
/// - No `..` components
|
||||
pub fn validate_memory_path(path: &str) -> Result<()> {
|
||||
if path.contains('\0') {
|
||||
anyhow::bail!("Path contains null bytes: {:?}", path);
|
||||
}
|
||||
let lower = path.to_ascii_lowercase();
|
||||
if lower.contains("%2e") || lower.contains("%2f") {
|
||||
anyhow::bail!("Path contains URL-encoded traversal sequences: {:?}", path);
|
||||
}
|
||||
if path.contains('\\') {
|
||||
anyhow::bail!("Path contains backslashes: {:?}", path);
|
||||
}
|
||||
if path.starts_with('/') {
|
||||
anyhow::bail!("Absolute Unix paths not allowed: {:?}", path);
|
||||
}
|
||||
// Windows-style absolute path: e.g. "C:" or "c:"
|
||||
if path.len() >= 2 {
|
||||
let mut chars = path.chars();
|
||||
let first = chars.next().unwrap();
|
||||
if first.is_ascii_alphabetic() && chars.next() == Some(':') {
|
||||
anyhow::bail!("Absolute Windows paths not allowed: {:?}", path);
|
||||
}
|
||||
}
|
||||
if path.split('/').any(|component| component == "..") {
|
||||
anyhow::bail!("Path traversal not allowed: {:?}", path);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TeamMemorySync
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Drives pull and push against the claude.ai team-memory API.
|
||||
pub struct TeamMemorySync {
|
||||
/// Base URL of the API, e.g. `"https://claude.ai"`.
|
||||
api_base: String,
|
||||
/// Repo identifier sent as a query parameter.
|
||||
repo: String,
|
||||
/// Bearer token for authentication.
|
||||
token: String,
|
||||
/// Local directory that mirrors the server's key namespace.
|
||||
team_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl TeamMemorySync {
|
||||
pub fn new(api_base: String, repo: String, token: String, team_dir: PathBuf) -> Self {
|
||||
Self { api_base, repo, token, team_dir }
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Pull
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Pull all entries from the server. Server wins: overwrites local files.
|
||||
///
|
||||
/// Updates `state.last_known_etag` and `state.server_checksums` on success.
|
||||
/// Returns `Ok(())` on HTTP 404 (no remote data yet).
|
||||
pub async fn pull(&self, state: &mut SyncState) -> Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"{}/api/claude_code/team_memory?repo={}",
|
||||
self.api_base,
|
||||
urlencoding::encode(&self.repo),
|
||||
);
|
||||
|
||||
let response = client
|
||||
.get(&url)
|
||||
.bearer_auth(&self.token)
|
||||
.send()
|
||||
.await
|
||||
.context("team memory pull: HTTP request failed")?;
|
||||
|
||||
let http_status = response.status();
|
||||
|
||||
if http_status.as_u16() == 404 {
|
||||
return Ok(()); // No remote data yet
|
||||
}
|
||||
|
||||
if !http_status.is_success() {
|
||||
anyhow::bail!("team memory pull failed with status {}", http_status);
|
||||
}
|
||||
|
||||
// Capture ETag before consuming the response body
|
||||
if let Some(etag) = response
|
||||
.headers()
|
||||
.get("etag")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
state.last_known_etag = Some(etag.to_string());
|
||||
}
|
||||
|
||||
let data: TeamMemoryData = response
|
||||
.json()
|
||||
.await
|
||||
.context("team memory pull: failed to parse response JSON")?;
|
||||
|
||||
state.server_checksums.clear();
|
||||
|
||||
for entry in &data.entries {
|
||||
validate_memory_path(&entry.key)
|
||||
.with_context(|| format!("server returned unsafe path: {:?}", entry.key))?;
|
||||
|
||||
state
|
||||
.server_checksums
|
||||
.insert(entry.key.clone(), entry.checksum.clone());
|
||||
|
||||
let local_path = self.team_dir.join(&entry.key);
|
||||
if let Some(parent) = local_path.parent() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.with_context(|| format!("create_dir_all for {:?}", parent))?;
|
||||
}
|
||||
|
||||
if entry.content.len() <= MAX_FILE_SIZE_BYTES {
|
||||
tokio::fs::write(&local_path, &entry.content)
|
||||
.await
|
||||
.with_context(|| format!("writing {:?}", local_path))?;
|
||||
}
|
||||
// Files exceeding MAX_FILE_SIZE_BYTES are silently skipped (same behaviour as push)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Push
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Push local changes to the server using delta upload.
|
||||
///
|
||||
/// Only entries whose local checksum differs from `state.server_checksums`
|
||||
/// are uploaded. Changed entries are packed into batches ≤ `MAX_PUT_BODY_BYTES`.
|
||||
pub async fn push(&self, state: &mut SyncState) -> Result<()> {
|
||||
let local_entries = self
|
||||
.scan_local_files()
|
||||
.await
|
||||
.context("team memory push: scanning local files")?;
|
||||
|
||||
// Delta: entries where local hash ≠ last-known server hash
|
||||
let changed: Vec<TeamMemoryEntry> = local_entries
|
||||
.into_iter()
|
||||
.filter(|entry| {
|
||||
state
|
||||
.server_checksums
|
||||
.get(&entry.key)
|
||||
.map(|s| s.as_str())
|
||||
!= Some(&entry.checksum)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if changed.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let batches = self.pack_batches(changed);
|
||||
for batch in batches {
|
||||
self.upload_batch(batch, state)
|
||||
.await
|
||||
.context("team memory push: uploading batch")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Internals
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Greedy bin-packing: pack entries into batches that each serialise to
|
||||
/// ≤ `MAX_PUT_BODY_BYTES`. Entries that individually exceed the limit go
|
||||
/// into singleton batches (server will reject them with 413, but that is
|
||||
/// the caller's problem).
|
||||
fn pack_batches(&self, entries: Vec<TeamMemoryEntry>) -> Vec<Vec<TeamMemoryEntry>> {
|
||||
let mut batches: Vec<Vec<TeamMemoryEntry>> = Vec::new();
|
||||
let mut current: Vec<TeamMemoryEntry> = Vec::new();
|
||||
let mut current_size: usize = 0;
|
||||
|
||||
for entry in entries {
|
||||
// Rough size estimate: key + content + JSON envelope overhead
|
||||
let entry_size = entry.key.len() + entry.content.len() + 100;
|
||||
|
||||
if entry_size > MAX_PUT_BODY_BYTES {
|
||||
// Oversized entry goes solo
|
||||
if !current.is_empty() {
|
||||
batches.push(std::mem::take(&mut current));
|
||||
current_size = 0;
|
||||
}
|
||||
batches.push(vec![entry]);
|
||||
continue;
|
||||
}
|
||||
|
||||
if current_size + entry_size > MAX_PUT_BODY_BYTES && !current.is_empty() {
|
||||
batches.push(std::mem::take(&mut current));
|
||||
current_size = 0;
|
||||
}
|
||||
|
||||
current_size += entry_size;
|
||||
current.push(entry);
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
batches.push(current);
|
||||
}
|
||||
|
||||
batches
|
||||
}
|
||||
|
||||
async fn upload_batch(
|
||||
&self,
|
||||
batch: Vec<TeamMemoryEntry>,
|
||||
state: &mut SyncState,
|
||||
) -> Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"{}/api/claude_code/team_memory?repo={}",
|
||||
self.api_base,
|
||||
urlencoding::encode(&self.repo),
|
||||
);
|
||||
|
||||
let body = serde_json::json!({ "entries": batch });
|
||||
|
||||
let mut req = client
|
||||
.put(&url)
|
||||
.bearer_auth(&self.token)
|
||||
.json(&body);
|
||||
|
||||
if let Some(etag) = &state.last_known_etag {
|
||||
req = req.header("If-Match", etag);
|
||||
}
|
||||
|
||||
let response = req
|
||||
.send()
|
||||
.await
|
||||
.context("team memory: PUT request failed")?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
|
||||
match status {
|
||||
200 | 201 | 204 => {
|
||||
if let Some(etag) = response
|
||||
.headers()
|
||||
.get("etag")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
state.last_known_etag = Some(etag.to_string());
|
||||
}
|
||||
// Update local checksum map to reflect uploaded state
|
||||
for entry in &batch {
|
||||
state
|
||||
.server_checksums
|
||||
.insert(entry.key.clone(), entry.checksum.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
412 => anyhow::bail!("Conflict (412 Precondition Failed): ETag mismatch, retry needed"),
|
||||
413 => anyhow::bail!("Payload too large (413)"),
|
||||
401 | 403 => anyhow::bail!("Authentication error ({})", status),
|
||||
_ => anyhow::bail!("Upload failed with status {}", status),
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively scan `team_dir` for `.md` files, returning entries sorted by key.
|
||||
async fn scan_local_files(&self) -> Result<Vec<TeamMemoryEntry>> {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
if !self.team_dir.exists() {
|
||||
return Ok(entries);
|
||||
}
|
||||
|
||||
// Iterative DFS using an explicit stack to avoid deep recursion
|
||||
let mut stack = vec![self.team_dir.clone()];
|
||||
|
||||
while let Some(dir) = stack.pop() {
|
||||
let mut read_dir = tokio::fs::read_dir(&dir)
|
||||
.await
|
||||
.with_context(|| format!("read_dir {:?}", dir))?;
|
||||
|
||||
while let Some(entry) = read_dir.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
stack.push(path);
|
||||
} else if path.extension().map(|e| e == "md").unwrap_or(false) {
|
||||
let content = tokio::fs::read_to_string(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading {:?}", path))?;
|
||||
|
||||
if content.len() > MAX_FILE_SIZE_BYTES {
|
||||
continue; // Skip files that are too large
|
||||
}
|
||||
|
||||
let key = path
|
||||
.strip_prefix(&self.team_dir)
|
||||
.unwrap()
|
||||
.to_string_lossy()
|
||||
.replace('\\', "/");
|
||||
|
||||
let checksum = content_checksum(&content);
|
||||
entries.push(TeamMemoryEntry { key, content, checksum });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entries.sort_by(|a, b| a.key.cmp(&b.key));
|
||||
Ok(entries)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Secret scanner
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A pattern matched during secret scanning.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SecretMatch {
|
||||
/// Short label identifying the secret type, e.g. `"Anthropic API key"`.
|
||||
pub label: String,
|
||||
}
|
||||
|
||||
/// Scan `content` for common high-confidence secret patterns.
|
||||
///
|
||||
/// Returns one [`SecretMatch`] per distinct pattern that fired. The actual
|
||||
/// matched text is intentionally **not** returned to avoid logging credentials.
|
||||
pub fn scan_for_secrets(content: &str) -> Vec<SecretMatch> {
|
||||
// Each tuple: (regex source, human-readable label)
|
||||
// Patterns ordered by likelihood of appearing in dev-team memory content.
|
||||
const PATTERNS: &[(&str, &str)] = &[
|
||||
// Cloud providers
|
||||
(r"(?:A3T[A-Z0-9]|AKIA|ASIA|ABIA|ACCA)[A-Z2-7]{16}", "AWS access key"),
|
||||
(r"AIza[\w-]{35}", "GCP API key"),
|
||||
// AI APIs
|
||||
(r"sk-ant-api03-[a-zA-Z0-9_\-]{93}AA", "Anthropic API key"),
|
||||
(r"sk-ant-admin01-[a-zA-Z0-9_\-]{93}AA", "Anthropic admin API key"),
|
||||
(r"sk-[a-zA-Z0-9]{20}T3BlbkFJ[a-zA-Z0-9]{20}", "OpenAI API key"),
|
||||
// Version control
|
||||
(r"ghp_[0-9a-zA-Z]{36}", "GitHub personal access token"),
|
||||
(r"github_pat_\w{82}", "GitHub fine-grained PAT"),
|
||||
(r"(?:ghu|ghs)_[0-9a-zA-Z]{36}", "GitHub app token"),
|
||||
(r"gho_[0-9a-zA-Z]{36}", "GitHub OAuth token"),
|
||||
(r"glpat-[\w-]{20}", "GitLab PAT"),
|
||||
// Communication
|
||||
(r"xoxb-[0-9]{10,13}-[0-9]{10,13}[a-zA-Z0-9-]*", "Slack bot token"),
|
||||
// Crypto / private keys
|
||||
(r"-----BEGIN[ A-Z0-9_-]{0,100}PRIVATE KEY", "Private key"),
|
||||
// Payments
|
||||
(r"(?:sk|rk)_(?:test|live|prod)_[a-zA-Z0-9]{10,99}", "Stripe secret key"),
|
||||
// NPM
|
||||
(r"npm_[a-zA-Z0-9]{36}", "NPM access token"),
|
||||
];
|
||||
|
||||
let mut findings: Vec<SecretMatch> = Vec::new();
|
||||
|
||||
for (pattern, label) in PATTERNS {
|
||||
// Lazily compile; the fn is not hot enough to warrant a static cache here
|
||||
if let Ok(re) = regex::Regex::new(pattern) {
|
||||
if re.is_match(content) {
|
||||
findings.push(SecretMatch { label: label.to_string() });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
findings
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// --- content_checksum ---
|
||||
|
||||
#[test]
|
||||
fn test_checksum_format() {
|
||||
let cs = content_checksum("hello");
|
||||
assert!(cs.starts_with("sha256:"), "checksum should start with sha256:");
|
||||
assert_eq!(cs.len(), "sha256:".len() + 64, "sha256 hex is 64 chars");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checksum_deterministic() {
|
||||
assert_eq!(content_checksum("foo"), content_checksum("foo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checksum_distinct() {
|
||||
assert_ne!(content_checksum("foo"), content_checksum("bar"));
|
||||
}
|
||||
|
||||
// --- validate_memory_path ---
|
||||
|
||||
#[test]
|
||||
fn test_valid_paths_accepted() {
|
||||
let ok_paths = [
|
||||
"MEMORY.md",
|
||||
"sub/dir/file.md",
|
||||
"sub/dir/another-file.md",
|
||||
"a.md",
|
||||
];
|
||||
for p in &ok_paths {
|
||||
assert!(validate_memory_path(p).is_ok(), "should accept: {}", p);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_null_byte_rejected() {
|
||||
assert!(validate_memory_path("foo\0bar").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_encoded_dot_rejected() {
|
||||
assert!(validate_memory_path("%2e%2e/secret").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_encoded_slash_rejected() {
|
||||
assert!(validate_memory_path("foo%2Fbar").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backslash_rejected() {
|
||||
assert!(validate_memory_path("foo\\bar").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absolute_unix_rejected() {
|
||||
assert!(validate_memory_path("/etc/passwd").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absolute_windows_rejected() {
|
||||
assert!(validate_memory_path("C:foo").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dotdot_rejected() {
|
||||
assert!(validate_memory_path("../secret").is_err());
|
||||
assert!(validate_memory_path("a/../../secret").is_err());
|
||||
}
|
||||
|
||||
// --- pack_batches ---
|
||||
|
||||
fn make_sync() -> TeamMemorySync {
|
||||
TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"owner/repo".to_string(),
|
||||
"token123".to_string(),
|
||||
PathBuf::from("/tmp/team"),
|
||||
)
|
||||
}
|
||||
|
||||
fn entry(key: &str, size: usize) -> TeamMemoryEntry {
|
||||
let content = "x".repeat(size);
|
||||
let checksum = content_checksum(&content);
|
||||
TeamMemoryEntry { key: key.to_string(), content, checksum }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_empty() {
|
||||
let sync = make_sync();
|
||||
let batches = sync.pack_batches(vec![]);
|
||||
assert!(batches.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_single_entry() {
|
||||
let sync = make_sync();
|
||||
let batches = sync.pack_batches(vec![entry("a.md", 100)]);
|
||||
assert_eq!(batches.len(), 1);
|
||||
assert_eq!(batches[0].len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_oversized_solo() {
|
||||
let sync = make_sync();
|
||||
// Entry > MAX_PUT_BODY_BYTES → goes solo
|
||||
let big = entry("big.md", MAX_PUT_BODY_BYTES + 1);
|
||||
let small = entry("small.md", 100);
|
||||
let batches = sync.pack_batches(vec![big, small]);
|
||||
// big is solo, small may be in a separate batch
|
||||
assert!(batches.len() >= 2);
|
||||
assert_eq!(batches[0].len(), 1, "oversized entry is solo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_batches_groups_small_entries() {
|
||||
let sync = make_sync();
|
||||
// Many small entries that each fit in one batch
|
||||
let entries: Vec<_> = (0..5).map(|i| entry(&format!("{i}.md"), 1024)).collect();
|
||||
let batches = sync.pack_batches(entries);
|
||||
// All 5 should fit in one batch (5 * ~1124 bytes << 200KB)
|
||||
assert_eq!(batches.len(), 1);
|
||||
assert_eq!(batches[0].len(), 5);
|
||||
}
|
||||
|
||||
// --- scan_for_secrets ---
|
||||
|
||||
#[test]
|
||||
fn test_no_secrets_clean() {
|
||||
let findings = scan_for_secrets("# Team notes\n\nSome markdown content here.");
|
||||
assert!(findings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detects_github_pat() {
|
||||
let content = format!("token: ghp_{}", "A".repeat(36));
|
||||
let findings = scan_for_secrets(&content);
|
||||
assert!(
|
||||
findings.iter().any(|m| m.label.contains("GitHub")),
|
||||
"should detect GitHub PAT"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detects_aws_key() {
|
||||
let content = "key=AKIAIOSFODNN7EXAMPLE";
|
||||
let findings = scan_for_secrets(content);
|
||||
assert!(
|
||||
findings.iter().any(|m| m.label.contains("AWS")),
|
||||
"should detect AWS key"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detects_private_key() {
|
||||
let content = "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n";
|
||||
let findings = scan_for_secrets(content);
|
||||
assert!(
|
||||
findings.iter().any(|m| m.label.contains("Private key")),
|
||||
"should detect private key"
|
||||
);
|
||||
}
|
||||
|
||||
// --- scan_local_files (integration-style) ---
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_empty_dir() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
assert!(entries.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_finds_md() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
tokio::fs::write(tmp.path().join("MEMORY.md"), "# Memory").await.unwrap();
|
||||
tokio::fs::write(tmp.path().join("ignore.txt"), "not md").await.unwrap();
|
||||
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert_eq!(entries[0].key, "MEMORY.md");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_sorted() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
tokio::fs::write(tmp.path().join("z.md"), "z").await.unwrap();
|
||||
tokio::fs::write(tmp.path().join("a.md"), "a").await.unwrap();
|
||||
tokio::fs::write(tmp.path().join("m.md"), "m").await.unwrap();
|
||||
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
let keys: Vec<_> = entries.iter().map(|e| e.key.as_str()).collect();
|
||||
assert_eq!(keys, vec!["a.md", "m.md", "z.md"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_local_files_checksums_match() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let content = "# Hello world";
|
||||
tokio::fs::write(tmp.path().join("MEMORY.md"), content).await.unwrap();
|
||||
|
||||
let sync = TeamMemorySync::new(
|
||||
"https://example.com".to_string(),
|
||||
"r".to_string(),
|
||||
"t".to_string(),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let entries = sync.scan_local_files().await.unwrap();
|
||||
assert_eq!(entries[0].checksum, content_checksum(content));
|
||||
}
|
||||
}
|
||||
192
src-rust/crates/core/src/voice.rs
Normal file
192
src-rust/crates/core/src/voice.rs
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
//! Voice mode availability checks
|
||||
|
||||
use crate::oauth::OAuthTokens;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum VoiceAvailability {
|
||||
Available,
|
||||
/// Not authenticated via first-party OAuth
|
||||
RequiresOAuth,
|
||||
/// OAuth token missing required scopes
|
||||
MissingScopes {
|
||||
required: Vec<String>,
|
||||
have: Vec<String>,
|
||||
},
|
||||
/// Feature disabled by kill-switch environment variable
|
||||
Disabled,
|
||||
/// Feature flag not enabled in this build
|
||||
NotEnabled,
|
||||
}
|
||||
|
||||
/// Scopes required for voice mode to function
|
||||
const VOICE_REQUIRED_SCOPES: &[&str] = &["user:inference", "user:profile"];
|
||||
|
||||
/// Environment variable that disables voice mode when set (any value)
|
||||
const KILL_SWITCH_ENV: &str = "CLAUDE_CODE_VOICE_DISABLED";
|
||||
|
||||
/// Check whether voice mode is available given the current OAuth tokens.
|
||||
///
|
||||
/// Pass `None` when the user is not authenticated via OAuth (API-key-only auth).
|
||||
pub fn check_voice_availability(tokens: Option<&OAuthTokens>) -> VoiceAvailability {
|
||||
// Check kill switch first — always wins
|
||||
if std::env::var(KILL_SWITCH_ENV).is_ok() {
|
||||
return VoiceAvailability::Disabled;
|
||||
}
|
||||
|
||||
// Voice requires first-party OAuth; API key alone is not sufficient
|
||||
let tokens = match tokens {
|
||||
Some(t) => t,
|
||||
None => return VoiceAvailability::RequiresOAuth,
|
||||
};
|
||||
|
||||
// OAuthTokens stores scopes as Vec<String>
|
||||
let have_scopes: &[String] = &tokens.scopes;
|
||||
|
||||
let missing: Vec<String> = VOICE_REQUIRED_SCOPES
|
||||
.iter()
|
||||
.filter(|&&required| !have_scopes.iter().any(|h| h == required))
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
|
||||
if !missing.is_empty() {
|
||||
return VoiceAvailability::MissingScopes {
|
||||
required: VOICE_REQUIRED_SCOPES
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
have: have_scopes.to_vec(),
|
||||
};
|
||||
}
|
||||
|
||||
VoiceAvailability::Available
|
||||
}
|
||||
|
||||
impl VoiceAvailability {
|
||||
/// Returns `true` when voice mode can be started.
|
||||
pub fn is_available(&self) -> bool {
|
||||
matches!(self, VoiceAvailability::Available)
|
||||
}
|
||||
|
||||
/// Returns a human-readable error message when voice is not available,
|
||||
/// or `None` when it is.
|
||||
pub fn error_message(&self) -> Option<String> {
|
||||
match self {
|
||||
VoiceAvailability::Available => None,
|
||||
VoiceAvailability::RequiresOAuth => Some(
|
||||
"Voice mode requires OAuth authentication. Run /login to authenticate.".to_string(),
|
||||
),
|
||||
VoiceAvailability::MissingScopes { required, have } => Some(format!(
|
||||
"Voice mode requires scopes: {}. Your token has: {}",
|
||||
required.join(", "),
|
||||
if have.is_empty() {
|
||||
"none".to_string()
|
||||
} else {
|
||||
have.join(", ")
|
||||
}
|
||||
)),
|
||||
VoiceAvailability::Disabled => Some("Voice mode is currently disabled.".to_string()),
|
||||
VoiceAvailability::NotEnabled => {
|
||||
Some("Voice mode is not enabled in this build.".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn tokens_with_scopes(scopes: Vec<&str>) -> OAuthTokens {
|
||||
OAuthTokens {
|
||||
access_token: "test_token".to_string(),
|
||||
scopes: scopes.iter().map(|s| s.to_string()).collect(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_tokens_requires_oauth() {
|
||||
let result = check_voice_availability(None);
|
||||
assert_eq!(result, VoiceAvailability::RequiresOAuth);
|
||||
assert!(!result.is_available());
|
||||
assert!(result.error_message().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_available_with_all_scopes() {
|
||||
let tokens = tokens_with_scopes(vec!["user:inference", "user:profile"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert_eq!(result, VoiceAvailability::Available);
|
||||
assert!(result.is_available());
|
||||
assert!(result.error_message().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_one_scope() {
|
||||
let tokens = tokens_with_scopes(vec!["user:inference"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert!(matches!(result, VoiceAvailability::MissingScopes { .. }));
|
||||
assert!(!result.is_available());
|
||||
let msg = result.error_message().unwrap();
|
||||
assert!(msg.contains("user:profile"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_all_scopes() {
|
||||
let tokens = tokens_with_scopes(vec!["org:create_api_key"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert!(matches!(result, VoiceAvailability::MissingScopes { .. }));
|
||||
assert!(!result.is_available());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_scopes_missing() {
|
||||
let tokens = tokens_with_scopes(vec![]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert!(
|
||||
matches!(result, VoiceAvailability::MissingScopes { ref have, .. } if have.is_empty())
|
||||
);
|
||||
let msg = result.error_message().unwrap();
|
||||
assert!(msg.contains("none"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kill_switch_disables_voice() {
|
||||
// Temporarily set the kill-switch env var
|
||||
std::env::set_var(KILL_SWITCH_ENV, "1");
|
||||
let tokens = tokens_with_scopes(vec!["user:inference", "user:profile"]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
std::env::remove_var(KILL_SWITCH_ENV);
|
||||
assert_eq!(result, VoiceAvailability::Disabled);
|
||||
assert!(!result.is_available());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kill_switch_beats_no_auth() {
|
||||
std::env::set_var(KILL_SWITCH_ENV, "true");
|
||||
let result = check_voice_availability(None);
|
||||
std::env::remove_var(KILL_SWITCH_ENV);
|
||||
// Kill switch wins — returns Disabled, not RequiresOAuth
|
||||
assert_eq!(result, VoiceAvailability::Disabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_enabled_error_message() {
|
||||
let v = VoiceAvailability::NotEnabled;
|
||||
assert!(!v.is_available());
|
||||
assert!(v.error_message().unwrap().contains("not enabled"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extra_scopes_still_available() {
|
||||
// Having more scopes than required is fine
|
||||
let tokens = tokens_with_scopes(vec![
|
||||
"user:inference",
|
||||
"user:profile",
|
||||
"org:create_api_key",
|
||||
"user:file_upload",
|
||||
]);
|
||||
let result = check_voice_availability(Some(&tokens));
|
||||
assert_eq!(result, VoiceAvailability::Available);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue