From 80216d5931cb07b59bc781f469f1b36609ff1104 Mon Sep 17 00:00:00 2001 From: Jay White Date: Wed, 11 Dec 2024 13:01:36 -0500 Subject: [PATCH] perf: avoid cloning sumcheck terms --- .../base/polynomial/multilinear_extension.rs | 17 ++++++++--------- .../sql/proof/composite_polynomial_builder.rs | 10 ++++++---- .../src/sql/proof/make_sumcheck_state.rs | 4 ++-- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs index a5fe4481c..ca361dae4 100644 --- a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs +++ b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs @@ -1,5 +1,5 @@ use crate::base::{database::Column, if_rayon, scalar::Scalar, slice_ops}; -use alloc::{rc::Rc, vec::Vec}; +use alloc::vec::Vec; use core::{ffi::c_void, fmt::Debug}; use num_traits::Zero; #[cfg(feature = "rayon")] @@ -15,7 +15,7 @@ pub trait MultilinearExtension: Debug { fn mul_add(&self, res: &mut [S], multiplier: &S); /// convert the MLE to a form that can be used in sumcheck - fn to_sumcheck_term(&self, num_vars: usize) -> Rc>; + fn to_sumcheck_term(&self, num_vars: usize) -> Vec; /// pointer to identify the slice forming the MLE fn id(&self) -> (*const c_void, usize); @@ -42,18 +42,17 @@ where slice_ops::mul_add_assign(res, *multiplier, &slice_ops::slice_cast(self)); } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { let values = self; let n = 1 << num_vars; assert!(n >= values.len()); - let scalars = if_rayon!(values.par_iter(), values.iter()) + if_rayon!(values.par_iter(), values.iter()) .map(Into::into) .chain(if_rayon!( rayon::iter::repeatn(Zero::zero(), n - values.len()), itertools::repeat_n(Zero::zero(), n - values.len()) )) - .collect(); - Rc::new(scalars) + .collect() } fn id(&self) -> (*const c_void, usize) { @@ -72,7 +71,7 @@ macro_rules! slice_like_mle_impl { (&self[..]).mul_add(res, multiplier) } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { (&self[..]).to_sumcheck_term(num_vars) } @@ -125,7 +124,7 @@ impl MultilinearExtension for &Column<'_, S> { } } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { match self { Column::Boolean(c) => c.to_sumcheck_term(num_vars), Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => { @@ -163,7 +162,7 @@ impl MultilinearExtension for Column<'_, S> { (&self).mul_add(res, multiplier); } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { (&self).to_sumcheck_term(num_vars) } diff --git a/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs b/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs index 62346bb1e..e554cbe85 100644 --- a/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs +++ b/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs @@ -32,7 +32,7 @@ impl CompositePolynomialBuilder { fr_multiplicands_degree1: vec![Zero::zero(); fr.len()], fr_multiplicands_rest: vec![], zerosum_multiplicands: vec![], - fr: fr.to_sumcheck_term(num_sumcheck_variables), + fr: fr.to_sumcheck_term(num_sumcheck_variables).into(), mles: IndexMap::default(), } } @@ -89,8 +89,8 @@ impl CompositePolynomialBuilder { deduplicated_terms.push(cached_term.clone()); } else { let new_term = term.to_sumcheck_term(self.num_sumcheck_variables); - self.mles.insert(id, new_term.clone()); - deduplicated_terms.push(new_term); + self.mles.insert(id, new_term.clone().into()); + deduplicated_terms.push(new_term.into()); } } deduplicated_terms @@ -103,7 +103,9 @@ impl CompositePolynomialBuilder { res.add_product( [ self.fr.clone(), - (&self.fr_multiplicands_degree1).to_sumcheck_term(self.num_sumcheck_variables), + (&self.fr_multiplicands_degree1) + .to_sumcheck_term(self.num_sumcheck_variables) + .into(), ], One::one(), ); diff --git a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs index 3f494f2bb..58dc90c22 100644 --- a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs +++ b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs @@ -98,11 +98,11 @@ impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> { fn flattened_ml_extensions(self) -> Vec> { self.entrywise_multipliers .into_iter() - .map(|mle| (&mle).to_sumcheck_term(self.num_vars).as_ref().clone()) + .map(|mle| (&mle).to_sumcheck_term(self.num_vars)) .chain( self.all_ml_extensions .iter() - .map(|mle| mle.to_sumcheck_term(self.num_vars).as_ref().clone()), + .map(|mle| mle.to_sumcheck_term(self.num_vars)), ) .collect() }