失敗する可能性があるfoldをtry_foldで楽に書く

std::iter::Iterator::fold という便利なメソッドがあります:

let vec = vec![1, 2, 3];

let sum = vec.iter().fold(0, |acc, x| acc + x);

assert_eq!(sum, 6);

いま、 fold に渡すクロージャ内の処理で失敗する可能性があり、 ? 演算子を使いたくなったとします。どのように書けば良いでしょうか?

fold の型は

fn fold<B, F>(self, init: B, f: F) -> B
where
    F: FnMut(B, Self::Item) -> B,

であり、 ? 演算子を使うために f の戻り値型(つまり B)は Result 型である必要があるので、初期値 initResult 型になり、例えば以下のようなコードを書くことになります:

fn parse_sum(vec: Vec<String>) -> Result<u64, std::num::ParseIntError> {
    vec.iter().fold(Ok(0), |wrapped_acc, string| {
        wrapped_acc.and_then(|acc| Ok(acc + string.parse::<u64>()?))
    })
}

fn main() {
    let vec = vec!["1".into(), "2".into(), "3".into()];

    let res = parse_sum(vec);

    assert_eq!(res.unwrap(), 6);
}

ところで ? 演算子はエラー時にその直近の関数(またはクロージャからしか脱出せず、 上記のコードで parse 時にエラーが発生したとしても foldイテレーションは止まりません:

fn parse_sum(vec: Vec<String>)V -> Result<u64, std::num::ParseIntError> {
    vec.iter()
        .inspect(|string| println!("passing {} to fold", string))
        .fold(Ok(0), |wrapped_acc, string| {
            wrapped_acc.and_then(|acc| Ok(acc + string.parse::<u64>()?))
        })
}

fn main() {
    let vec = vec!["1".into(), "2".into(), "hello".into(), "3".into()];

    let res = parse_sum(vec);

    assert!(res.is_err());
}
passing 1 to fold
passing 2 to fold
passing hello to fold
passing 3 to fold

できればエラーが発生したらその時点で早期にエラーを返してほしいですね。

そんなときに try_fold が便利に使えます。

try_fold の型は以下のようになっています:

fn try_fold<B, F, R>(&mut self, init: B, f: F) -> R
where
    F: FnMut(B, Self::Item) -> R,
    R: Try<Ok = B>,

さきほどの例は、初期値を Ok でラップしたり and_then を使ったりすることなく以下のように書き直せます:

fn parse_sum(vec: Vec<String>) -> Result<u64, std::num::ParseIntError> {
    vec.iter()
        .try_fold(0, |acc, string| Ok(acc + string.parse::<u64>()?))
}

fn main() {
    let vec = vec!["1".into(), "2".into(), "3".into()];

    let res = parse_sum(vec);

    assert_eq!(res.unwrap(), 6);
}

さらに try_foldクロージャがエラーを返した時点でただちにそのエラーを返します:

fn parse_sum(vec: Vec<String>) -> Result<u64, std::num::ParseIntError> {
    vec.iter()
        .inspect(|string| println!("passing {} to try_fold", string))
        .try_fold(0, |acc, string| Ok(acc + string.parse::<u64>()?))
}

fn main() {
    let vec = vec!["1".into(), "2".into(), "hello".into(), "3".into()];

    let res = parse_sum(vec);

    assert!(res.is_err());
}
passing 1 to try_fold
passing 2 to try_fold
passing hello to try_fold

便利ですね。

普段 fold だけに目が行きがちなので書いてみました。