Skip to content

Commit

Permalink
perf: avoid cloning sumcheck terms
Browse files Browse the repository at this point in the history
  • Loading branch information
JayWhite2357 committed Dec 12, 2024
1 parent 52c3f90 commit 80216d5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
17 changes: 8 additions & 9 deletions crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::base::{database::Column, if_rayon, scalar::Scalar, slice_ops};
use alloc::{rc::Rc, vec::Vec};
use alloc::vec::Vec;
use core::{ffi::c_void, fmt::Debug};
use num_traits::Zero;
#[cfg(feature = "rayon")]
Expand All @@ -15,7 +15,7 @@ pub trait MultilinearExtension<S: Scalar>: Debug {
fn mul_add(&self, res: &mut [S], multiplier: &S);

/// convert the MLE to a form that can be used in sumcheck
fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>>;
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S>;

/// pointer to identify the slice forming the MLE
fn id(&self) -> (*const c_void, usize);
Expand All @@ -42,18 +42,17 @@ where
slice_ops::mul_add_assign(res, *multiplier, &slice_ops::slice_cast(self));
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
let values = self;
let n = 1 << num_vars;
assert!(n >= values.len());
let scalars = if_rayon!(values.par_iter(), values.iter())
if_rayon!(values.par_iter(), values.iter())
.map(Into::into)
.chain(if_rayon!(
rayon::iter::repeatn(Zero::zero(), n - values.len()),
itertools::repeat_n(Zero::zero(), n - values.len())
))
.collect();
Rc::new(scalars)
.collect()
}

fn id(&self) -> (*const c_void, usize) {
Expand All @@ -72,7 +71,7 @@ macro_rules! slice_like_mle_impl {
(&self[..]).mul_add(res, multiplier)
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
(&self[..]).to_sumcheck_term(num_vars)
}

Expand Down Expand Up @@ -125,7 +124,7 @@ impl<S: Scalar> MultilinearExtension<S> for &Column<'_, S> {
}
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
match self {
Column::Boolean(c) => c.to_sumcheck_term(num_vars),
Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => {
Expand Down Expand Up @@ -163,7 +162,7 @@ impl<S: Scalar> MultilinearExtension<S> for Column<'_, S> {
(&self).mul_add(res, multiplier);
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
(&self).to_sumcheck_term(num_vars)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl<S: Scalar> CompositePolynomialBuilder<S> {
fr_multiplicands_degree1: vec![Zero::zero(); fr.len()],
fr_multiplicands_rest: vec![],
zerosum_multiplicands: vec![],
fr: fr.to_sumcheck_term(num_sumcheck_variables),
fr: fr.to_sumcheck_term(num_sumcheck_variables).into(),
mles: IndexMap::default(),
}
}
Expand Down Expand Up @@ -89,8 +89,8 @@ impl<S: Scalar> CompositePolynomialBuilder<S> {
deduplicated_terms.push(cached_term.clone());
} else {
let new_term = term.to_sumcheck_term(self.num_sumcheck_variables);
self.mles.insert(id, new_term.clone());
deduplicated_terms.push(new_term);
self.mles.insert(id, new_term.clone().into());
deduplicated_terms.push(new_term.into());
}
}
deduplicated_terms
Expand All @@ -103,7 +103,9 @@ impl<S: Scalar> CompositePolynomialBuilder<S> {
res.add_product(
[
self.fr.clone(),
(&self.fr_multiplicands_degree1).to_sumcheck_term(self.num_sumcheck_variables),
(&self.fr_multiplicands_degree1)
.to_sumcheck_term(self.num_sumcheck_variables)
.into(),
],
One::one(),
);
Expand Down
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> {
fn flattened_ml_extensions(self) -> Vec<Vec<S>> {
self.entrywise_multipliers
.into_iter()
.map(|mle| (&mle).to_sumcheck_term(self.num_vars).as_ref().clone())
.map(|mle| (&mle).to_sumcheck_term(self.num_vars))
.chain(
self.all_ml_extensions
.iter()
.map(|mle| mle.to_sumcheck_term(self.num_vars).as_ref().clone()),
.map(|mle| mle.to_sumcheck_term(self.num_vars)),
)
.collect()
}
Expand Down

0 comments on commit 80216d5

Please sign in to comment.