-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add yul implementation of lagrange basis evaluation
- Loading branch information
1 parent
a3da1a3
commit 83714f5
Showing
1 changed file
with
74 additions
and
0 deletions.
There are no files selected for viewing
74 changes: 74 additions & 0 deletions
74
crates/proof-of-sql/sol_src/base/LagrangeBasisEvaluation.sol
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// SPDX-License-Identifier: UNLICENSED | ||
pragma solidity ^0.8.13; | ||
|
||
contract LagrangeBasisEvaluation { | ||
function computeTruncatedLagrangeBasisSum(uint256 length0, bytes memory point0, uint256 numVars0, uint256 modulus0) | ||
public | ||
pure | ||
returns (uint256 result0) | ||
{ | ||
// solhint-disable-next-line no-inline-assembly | ||
assembly { | ||
function compute_truncated_lagrange_basis_sum(length, point, num_vars, modulus) -> result { | ||
let ONE := add(modulus, 1) | ||
// result := 0 // implicitly set by the EVM | ||
|
||
// Invariant that holds within the for loop: | ||
// 0 <= result <= modulus + 1 | ||
// This invariant reduces modulus operations. | ||
for {} num_vars {} { | ||
switch and(length, 1) | ||
case 0 { result := mulmod(result, sub(ONE, mod(mload(point), modulus)), modulus) } | ||
default { result := sub(ONE, mulmod(sub(ONE, result), mload(point), modulus)) } | ||
num_vars := sub(num_vars, 1) | ||
length := shr(1, length) | ||
point := add(point, 32) | ||
} | ||
switch length | ||
case 0 { result := mod(result, modulus) } | ||
default { result := 1 } | ||
} | ||
result0 := compute_truncated_lagrange_basis_sum(length0, add(point0, 32), numVars0, modulus0) | ||
} | ||
} | ||
|
||
uint256 private constant TEST_MODULUS = 10007; | ||
|
||
function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith0Variables() public pure { | ||
bytes memory point = hex""; | ||
assert(computeTruncatedLagrangeBasisSum(1, point, 0, TEST_MODULUS) == 1); | ||
assert(computeTruncatedLagrangeBasisSum(0, point, 0, TEST_MODULUS) == 0); | ||
} | ||
|
||
function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith1Variables() public pure { | ||
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002"; | ||
assert(computeTruncatedLagrangeBasisSum(2, point, 1, TEST_MODULUS) == 1); | ||
assert(computeTruncatedLagrangeBasisSum(1, point, 1, TEST_MODULUS) == TEST_MODULUS - 1); | ||
assert(computeTruncatedLagrangeBasisSum(0, point, 1, TEST_MODULUS) == 0); | ||
} | ||
|
||
function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith2Variables() public pure { | ||
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002" | ||
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005"; | ||
assert(computeTruncatedLagrangeBasisSum(4, point, 2, TEST_MODULUS) == 1); | ||
assert(computeTruncatedLagrangeBasisSum(3, point, 2, TEST_MODULUS) == TEST_MODULUS - 9); | ||
assert(computeTruncatedLagrangeBasisSum(2, point, 2, TEST_MODULUS) == TEST_MODULUS - 4); | ||
assert(computeTruncatedLagrangeBasisSum(1, point, 2, TEST_MODULUS) == 4); | ||
assert(computeTruncatedLagrangeBasisSum(0, point, 2, TEST_MODULUS) == 0); | ||
} | ||
|
||
function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith3Variables() public pure { | ||
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002" | ||
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005" | ||
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000007"; | ||
assert(computeTruncatedLagrangeBasisSum(8, point, 3, TEST_MODULUS) == 1); | ||
assert(computeTruncatedLagrangeBasisSum(7, point, 3, TEST_MODULUS) == TEST_MODULUS - 69); | ||
assert(computeTruncatedLagrangeBasisSum(6, point, 3, TEST_MODULUS) == TEST_MODULUS - 34); | ||
assert(computeTruncatedLagrangeBasisSum(5, point, 3, TEST_MODULUS) == 22); | ||
assert(computeTruncatedLagrangeBasisSum(4, point, 3, TEST_MODULUS) == TEST_MODULUS - 6); | ||
assert(computeTruncatedLagrangeBasisSum(3, point, 3, TEST_MODULUS) == 54); | ||
assert(computeTruncatedLagrangeBasisSum(2, point, 3, TEST_MODULUS) == 24); | ||
assert(computeTruncatedLagrangeBasisSum(1, point, 3, TEST_MODULUS) == TEST_MODULUS - 24); | ||
assert(computeTruncatedLagrangeBasisSum(0, point, 3, TEST_MODULUS) == 0); | ||
} | ||
} |