first pass at prefix rules

This commit is contained in:
kevin zhao
2025-11-10 10:38:08 -08:00
parent 6c384eb9c6
commit 773177ec8b
13 changed files with 664 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
use crate::error::Error;
use crate::error::Result;
pub fn tokenize_command(raw: &str) -> Result<Vec<String>> {
shlex::split(raw).ok_or_else(|| Error::TokenizationFailed {
example: raw.to_string(),
reason: "invalid shell tokens".to_string(),
})
}

View File

@@ -0,0 +1,33 @@
use serde::Deserialize;
use serde::Serialize;
use crate::error::Error;
use crate::error::Result;
/// Outcome a policy rule assigns to a matching command.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Decision {
    Allow,
    Prompt,
    Forbidden,
}

impl Decision {
    /// Parse a lowercase decision name (`"allow"`, `"prompt"`, `"forbidden"`).
    ///
    /// # Errors
    /// Returns `Error::InvalidDecision` for any other input.
    pub fn parse(raw: &str) -> Result<Self> {
        match raw {
            "allow" => Ok(Self::Allow),
            "prompt" => Ok(Self::Prompt),
            "forbidden" => Ok(Self::Forbidden),
            other => Err(Error::InvalidDecision(other.to_string())),
        }
    }

    /// Returns true if `self` is stricter (less permissive) than `other`.
    pub fn is_stricter_than(self, other: Self) -> bool {
        // Rank decisions from most permissive (0) to least permissive (2);
        // strictly greater rank means strictly stricter.
        let rank = |decision: Self| -> u8 {
            match decision {
                Self::Allow => 0,
                Self::Prompt => 1,
                Self::Forbidden => 2,
            }
        };
        rank(self) > rank(other)
    }
}

View File

@@ -0,0 +1,36 @@
# Allow `git status`. Prefix rules match any command that *starts with* the
# pattern, so trailing arguments (e.g. after `--`) are still allowed.
# No `decision` is given, so the parser's default ("allow") applies.
# The `match` / `not_match` lists are self-tests checked at parse time.
prefix_rule(
    id = "git_status",
    pattern = ["git", "status"],
    match = [
        "git status",
        "git status -- path/to/file",
    ],
    not_match = [
        "git statusx",
        "git reset --hard",
    ],
)

# Prompt before npm installs. An inner list in `pattern` declares
# interchangeable alternatives for that token position ("i" or "install").
prefix_rule(
    id = "npm_install",
    pattern = ["npm", ["i", "install"]],
    decision = "prompt",
    match = [
        "npm i",
        "npm install",
        "npm install lodash",
    ],
    not_match = [
        "npmx install",
        "npm outdated",
    ],
)

# Never allow a destructive hard reset.
prefix_rule(
    id = "git_reset_hard",
    pattern = ["git", "reset", "--hard"],
    decision = "forbidden",
    match = [
        "git reset --hard",
    ],
)

View File

@@ -0,0 +1,19 @@
use thiserror::Error;
/// Crate-wide result alias using [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

/// Failure modes surfaced while parsing or validating a policy.
#[derive(Debug, Error)]
pub enum Error {
    /// A `decision` string other than "allow"/"prompt"/"forbidden".
    #[error("invalid decision: {0}")]
    InvalidDecision(String),
    /// A `pattern` element that is not a string or non-empty string list.
    #[error("invalid pattern element: {0}")]
    InvalidPattern(String),
    /// An example command string could not be shell-tokenized.
    #[error("failed to tokenize example `{example}`: {reason}")]
    TokenizationFailed { example: String, reason: String },
    /// A `match` example failed to match its own rule (self-test failure).
    #[error("expected example to match rule `{rule_id}`: {example}")]
    ExampleDidNotMatch { rule_id: String, example: String },
    /// A `not_match` example unexpectedly matched its rule (self-test failure).
    #[error("expected example to not match rule `{rule_id}`: {example}")]
    ExampleDidMatch { rule_id: String, example: String },
    /// Any parse or evaluation error reported by the Starlark interpreter.
    #[error("starlark error: {0}")]
    Starlark(String),
}

View File

@@ -0,0 +1,23 @@
pub mod command;
pub mod decision;
pub mod error;
pub mod parser;
pub mod policy;
pub mod rule;
pub use command::tokenize_command;
pub use decision::Decision;
pub use error::Error;
pub use error::Result;
pub use parser::PolicyParser;
pub use policy::Evaluation;
pub use policy::Policy;
pub use rule::Rule;
pub use rule::RuleMatch;
/// Load the default bundled policy.
///
/// The policy text is embedded at compile time from `src/default.policy`.
///
/// # Errors
/// Propagates any parse failure from the embedded policy text.
pub fn load_default_policy() -> Result<Policy> {
    PolicyParser::new("default.policy", include_str!("default.policy")).parse()
}

View File

@@ -0,0 +1,85 @@
use std::fs;
use std::path::Path;
use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
use codex_execpolicy2::PolicyParser;
use codex_execpolicy2::load_default_policy;
use codex_execpolicy2::tokenize_command;
/// CLI entry point: consume leading `--policy`/`-p` flags, then dispatch
/// on the first non-flag argument as the subcommand.
fn main() -> Result<()> {
    let mut args = std::env::args().skip(1);
    let mut policy_path: Option<String> = None;
    while let Some(arg) = args.next() {
        match arg.as_str() {
            "--policy" | "-p" => {
                policy_path = Some(
                    args.next()
                        .context("expected a policy path after --policy/-p")?,
                );
            }
            _ => {
                // First non-flag argument is the subcommand; everything
                // after it is forwarded to the handler untouched.
                return run_subcommand(arg, policy_path, args.collect());
            }
        }
    }
    print_usage();
    bail!("missing subcommand")
}
/// Dispatch `subcommand` to its handler; unknown names print usage and fail.
fn run_subcommand(
    subcommand: String,
    policy_path: Option<String>,
    args: Vec<String>,
) -> Result<()> {
    if subcommand == "check" {
        return cmd_check(policy_path, args);
    }
    print_usage();
    bail!("unknown subcommand: {subcommand}")
}
/// Evaluate a command against the policy and print the matching rule's
/// evaluation as pretty JSON, or "no match" when no rule applies.
fn cmd_check(policy_path: Option<String>, args: Vec<String>) -> Result<()> {
    if args.is_empty() {
        bail!("usage: codex-execpolicy2 check <command tokens...|\"command string\">");
    }
    let policy = load_policy(policy_path)?;
    // A single argument is treated as a full command string and shell-tokenized;
    // multiple arguments are taken as the token list verbatim.
    let tokens = match args.as_slice() {
        [single] => tokenize_command(single)?,
        _ => args,
    };
    if let Some(eval) = policy.evaluate(&tokens) {
        println!("{}", serde_json::to_string_pretty(&eval)?);
    } else {
        println!("no match");
    }
    Ok(())
}
/// Load and parse the policy at `policy_path` when given, otherwise fall
/// back to the bundled default policy.
///
/// # Errors
/// Fails if the file cannot be read or either policy fails to parse.
fn load_policy(policy_path: Option<String>) -> Result<codex_execpolicy2::Policy> {
    if let Some(path) = policy_path {
        // `path` is already a `String`; the previous `Path::new(&path).display()`
        // wrapper produced the identical text, so format it directly.
        let content = fs::read_to_string(&path)
            .with_context(|| format!("failed to read policy at {path}"))?;
        let parser = PolicyParser::new(&path, &content);
        return Ok(parser.parse()?);
    }
    Ok(load_default_policy()?)
}
/// Print CLI usage to stderr.
fn print_usage() {
    eprintln!(
        "usage:
codex-execpolicy2 [--policy path] check <command tokens...|\"command string\">"
    );
}

View File

@@ -0,0 +1,205 @@
use std::cell::RefCell;
use starlark::any::ProvidesStaticType;
use starlark::environment::GlobalsBuilder;
use starlark::environment::Module;
use starlark::eval::Evaluator;
use starlark::starlark_module;
use starlark::syntax::AstModule;
use starlark::syntax::Dialect;
use starlark::values::Value;
use starlark::values::list::ListRef;
use starlark::values::list::UnpackList;
use starlark::values::none::NoneType;
use crate::command::tokenize_command;
use crate::decision::Decision;
use crate::error::Error;
use crate::error::Result;
use crate::rule::Rule;
/// Parses a Starlark policy document into a [`crate::policy::Policy`].
pub struct PolicyParser {
    // Label for the source (e.g. a file path), used in diagnostics.
    policy_source: String,
    // Raw Starlark text of the policy, evaluated by `parse`.
    unparsed_policy: String,
}
impl PolicyParser {
    /// Create a parser over `unparsed_policy`; `policy_source` names its
    /// origin (e.g. a file path) for error messages.
    pub fn new(policy_source: &str, unparsed_policy: &str) -> Self {
        Self {
            policy_source: policy_source.to_string(),
            unparsed_policy: unparsed_policy.to_string(),
        }
    }

    /// Evaluate the policy as a Starlark module, collecting the rules
    /// registered by `prefix_rule(...)` calls into a [`crate::policy::Policy`].
    ///
    /// # Errors
    /// Returns `Error::Starlark` if the document fails to parse or evaluate
    /// (rule self-test failures also surface through evaluation).
    pub fn parse(&self) -> Result<crate::policy::Policy> {
        // Extended dialect with f-strings enabled for richer policy syntax.
        let mut dialect = Dialect::Extended.clone();
        dialect.enable_f_strings = true;
        let ast = AstModule::parse(&self.policy_source, self.unparsed_policy.clone(), &dialect)
            .map_err(|e| Error::Starlark(e.to_string()))?;
        let globals = GlobalsBuilder::standard().with(policy_builtins).build();
        let module = Module::new();
        let builder = PolicyBuilder::new();
        {
            // Scope the evaluator so its borrow of `builder` (via `extra`)
            // ends before we consume the builder below.
            let mut eval = Evaluator::new(&module);
            eval.extra = Some(&builder);
            eval.eval_module(ast, &globals)
                .map_err(|e| Error::Starlark(e.to_string()))?;
        }
        Ok(builder.build())
    }
}
/// Mutable accumulator threaded through Starlark evaluation via
/// `Evaluator::extra`; collects rules as `prefix_rule(...)` builtins run.
#[derive(Debug, ProvidesStaticType)]
struct PolicyBuilder {
    rules: RefCell<Vec<Rule>>,
    next_auto_id: RefCell<i64>,
}

impl PolicyBuilder {
    /// Start with no rules and the auto-id counter at zero.
    fn new() -> Self {
        Self {
            rules: RefCell::new(Vec::new()),
            next_auto_id: RefCell::new(0),
        }
    }

    /// Hand out sequential fallback ids ("rule_0", "rule_1", ...) for rules
    /// declared without an explicit `id`.
    fn alloc_id(&self) -> String {
        let mut counter = self.next_auto_id.borrow_mut();
        let assigned = *counter;
        *counter += 1;
        format!("rule_{assigned}")
    }

    /// Record a fully validated rule.
    fn add_rule(&self, rule: Rule) {
        self.rules.borrow_mut().push(rule);
    }

    /// Snapshot the accumulated rules into an immutable policy.
    fn build(&self) -> crate::policy::Policy {
        crate::policy::Policy::new(self.rules.borrow().clone())
    }
}
/// One element of a `prefix_rule` pattern: either a literal token or a
/// set of interchangeable alternative tokens.
#[derive(Debug)]
enum PatternPart {
    Single(String),
    Alts(Vec<String>),
}

/// Expand a pattern into every concrete token prefix it denotes: the
/// cartesian product over each part's alternatives, in declaration order.
/// An empty pattern yields one empty prefix.
fn expand_pattern(parts: &[PatternPart]) -> Vec<Vec<String>> {
    parts.iter().fold(vec![Vec::new()], |prefixes, part| {
        let alternatives: &[String] = match part {
            PatternPart::Single(token) => std::slice::from_ref(token),
            PatternPart::Alts(tokens) => tokens,
        };
        prefixes
            .iter()
            .flat_map(|prefix| {
                alternatives.iter().map(move |alt| {
                    let mut expanded = prefix.clone();
                    expanded.push(alt.clone());
                    expanded
                })
            })
            .collect()
    })
}
/// Convert the Starlark `pattern` argument into concrete token prefixes.
///
/// Each element must be a string (a literal token) or a non-empty list of
/// strings (alternatives for that position); the result is the cartesian
/// product expansion via [`expand_pattern`].
///
/// # Errors
/// Returns `Error::InvalidPattern` for a non-string/non-list element, a
/// non-string alternative, or an empty alternatives list.
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
    let mut parts = Vec::new();
    for item in pattern.items {
        // A bare string is a single literal token.
        if let Some(s) = item.unpack_str() {
            parts.push(PatternPart::Single(s.to_string()));
            continue;
        }
        // Otherwise it must be a list whose members are all strings.
        let mut alts = Vec::new();
        if let Some(list) = ListRef::from_value(item) {
            for value in list.content() {
                let s = value.unpack_str().ok_or_else(|| {
                    Error::InvalidPattern("pattern alternative must be a string".to_string())
                })?;
                alts.push(s.to_string());
            }
        } else {
            return Err(Error::InvalidPattern(
                "pattern element must be a string or list of strings".to_string(),
            ));
        }
        // An empty alternatives list would silently match nothing; reject it.
        if alts.is_empty() {
            return Err(Error::InvalidPattern(
                "pattern alternatives cannot be empty".to_string(),
            ));
        }
        parts.push(PatternPart::Alts(alts));
    }
    Ok(expand_pattern(&parts))
}
/// Starlark builtins made available to policy files.
#[starlark_module]
fn policy_builtins(builder: &mut GlobalsBuilder) {
    /// Register a prefix rule.
    ///
    /// `pattern` elements are a literal token or a list of alternative
    /// tokens; `match`/`not_match` are example command strings validated
    /// against the rule at parse time. `decision` defaults to "allow";
    /// `id` defaults to an auto-allocated "rule_N".
    fn prefix_rule<'v>(
        pattern: UnpackList<Value<'v>>,
        decision: Option<&'v str>,
        r#match: Option<UnpackList<&'v str>>,
        not_match: Option<UnpackList<&'v str>>,
        id: Option<&'v str>,
        eval: &mut Evaluator<'v, '_, '_>,
    ) -> anyhow::Result<NoneType> {
        // Omitted decision defaults to the most permissive outcome.
        let decision = match decision {
            Some(raw) => Decision::parse(raw)?,
            None => Decision::Allow,
        };
        let prefixes = parse_pattern(pattern)?;
        // Shell-tokenize the example strings; a bad example fails the parse.
        let positive_examples: Vec<Vec<String>> = r#match
            .map(|examples| {
                examples
                    .items
                    .into_iter()
                    .map(tokenize_command)
                    .collect::<Result<Vec<_>>>()
            })
            .transpose()?
            .unwrap_or_default();
        let negative_examples: Vec<Vec<String>> = not_match
            .map(|examples| {
                examples
                    .items
                    .into_iter()
                    .map(tokenize_command)
                    .collect::<Result<Vec<_>>>()
            })
            .transpose()?
            .unwrap_or_default();
        // Missing id: pull the PolicyBuilder out of `eval.extra` and
        // allocate a sequential fallback id.
        let id = id.map(std::string::ToString::to_string).unwrap_or_else(|| {
            #[expect(clippy::unwrap_used)]
            let builder = eval
                .extra
                .as_ref()
                .unwrap()
                .downcast_ref::<PolicyBuilder>()
                .unwrap();
            builder.alloc_id()
        });
        let rule = Rule {
            id: id.clone(),
            prefixes,
            decision,
        };
        // Self-test the rule against its examples before registering it.
        rule.validate_examples(&positive_examples, &negative_examples)?;
        // `extra` is always set by PolicyParser::parse before evaluation,
        // so the unwraps cannot fire in practice.
        #[expect(clippy::unwrap_used)]
        let builder = eval
            .extra
            .as_ref()
            .unwrap()
            .downcast_ref::<PolicyBuilder>()
            .unwrap();
        builder.add_rule(rule);
        Ok(NoneType)
    }
}

View File

@@ -0,0 +1,59 @@
use crate::decision::Decision;
use crate::rule::Rule;
use crate::rule::RuleMatch;
use serde::Deserialize;
use serde::Serialize;
/// An ordered collection of prefix rules.
#[derive(Clone, Debug)]
pub struct Policy {
    rules: Vec<Rule>,
}

/// Serializable result of evaluating a command against a [`Policy`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Evaluation {
    /// Id of the rule that produced this evaluation.
    pub rule_id: String,
    /// Decision carried by the matching rule.
    pub decision: Decision,
    /// The rule prefix the command started with.
    pub matched_prefix: Vec<String>,
    /// Command tokens remaining after the matched prefix.
    pub remainder: Vec<String>,
}
impl From<RuleMatch> for Evaluation {
fn from(value: RuleMatch) -> Self {
Self {
rule_id: value.rule_id,
decision: value.decision,
matched_prefix: value.matched_prefix,
remainder: value.remainder,
}
}
}
impl Policy {
    /// Build a policy from an already-parsed rule list.
    pub fn new(rules: Vec<Rule>) -> Self {
        Self { rules }
    }

    /// All rules, in declaration order.
    pub fn rules(&self) -> &[Rule] {
        &self.rules
    }

    /// Evaluate `cmd` against every rule and return the match with the
    /// strictest decision; on ties the earlier rule wins. Returns `None`
    /// when no rule matches.
    pub fn evaluate(&self, cmd: &[String]) -> Option<Evaluation> {
        self.rules
            .iter()
            .filter_map(|rule| rule.matches(cmd).map(Evaluation::from))
            .reduce(|strictest, candidate| {
                if candidate.decision.is_stricter_than(strictest.decision) {
                    candidate
                } else {
                    strictest
                }
            })
    }
}

View File

@@ -0,0 +1,66 @@
use crate::decision::Decision;
use crate::error::Error;
use crate::error::Result;
/// A command prefix rule: matches any command that starts with one of
/// `prefixes`, assigning `decision` to it.
#[derive(Clone, Debug)]
pub struct Rule {
    /// Unique identifier (explicit in the policy, or parser-allocated).
    pub id: String,
    /// Fully expanded token prefixes; matching any one matches the rule.
    pub prefixes: Vec<Vec<String>>,
    /// Decision applied when the rule matches.
    pub decision: Decision,
}

/// Details of a successful [`Rule`] match.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RuleMatch {
    pub rule_id: String,
    /// The specific prefix that matched.
    pub matched_prefix: Vec<String>,
    /// Command tokens following the matched prefix.
    pub remainder: Vec<String>,
    pub decision: Decision,
}
impl Rule {
    /// Return details for the first configured prefix that `cmd` starts
    /// with, or `None` when no prefix matches.
    pub fn matches(&self, cmd: &[String]) -> Option<RuleMatch> {
        self.prefixes.iter().find_map(|prefix| {
            // `starts_with` also rejects prefixes longer than `cmd`.
            cmd.starts_with(prefix).then(|| RuleMatch {
                rule_id: self.id.clone(),
                matched_prefix: prefix.clone(),
                remainder: cmd[prefix.len()..].to_vec(),
                decision: self.decision,
            })
        })
    }

    /// Self-test the rule: every `positive` example must match and no
    /// `negative` example may match.
    ///
    /// # Errors
    /// Returns `Error::ExampleDidNotMatch` / `Error::ExampleDidMatch`
    /// naming the first offending example.
    pub fn validate_examples(
        &self,
        positive: &[Vec<String>],
        negative: &[Vec<String>],
    ) -> Result<()> {
        if let Some(example) = positive.iter().find(|e| self.matches(e).is_none()) {
            return Err(Error::ExampleDidNotMatch {
                rule_id: self.id.clone(),
                example: example.join(" "),
            });
        }
        if let Some(example) = negative.iter().find(|e| self.matches(e).is_some()) {
            return Err(Error::ExampleDidMatch {
                rule_id: self.id.clone(),
                example: example.join(" "),
            });
        }
        Ok(())
    }
}