Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support constrained decoding #1038

Open
wants to merge 86 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
bf38535
add llguidance based logits processor
Taka152 Oct 31, 2024
c151d52
add unit test
Taka152 Oct 31, 2024
9d5a8a0
constrained decoding fixes (#1023)
mmoskal Nov 1, 2024
48c3e96
add test grammars
Taka152 Nov 1, 2024
d70b849
support cuda
Taka152 Nov 1, 2024
6b90c1c
use tokenize.json to generate token_bytes
Taka152 Nov 4, 2024
bdb9ca4
fix win build
Taka152 Nov 5, 2024
a25de8e
async compute mask
Taka152 Nov 6, 2024
edc0bae
add llguidance build in cmake
Taka152 Nov 6, 2024
09861d7
update windows build
Taka152 Nov 6, 2024
ee94df8
clean cmake
Taka152 Nov 6, 2024
4d077cf
add install rust to GHA
Taka152 Nov 6, 2024
15b20b8
test action
Taka152 Nov 6, 2024
6029510
test win cpu build action
Taka152 Nov 6, 2024
4d8d8a6
update win build action
Taka152 Nov 6, 2024
346f88c
update win build action
Taka152 Nov 6, 2024
c00d8fa
update win build action
Taka152 Nov 6, 2024
39fb7ed
update win build action
Taka152 Nov 6, 2024
8038723
update win build action
Taka152 Nov 6, 2024
324f550
update win build action
Taka152 Nov 6, 2024
8722727
update win build action
Taka152 Nov 6, 2024
d620422
add rust install to workflows
Taka152 Nov 7, 2024
c1ede01
support batch infer
Taka152 Nov 7, 2024
d2f47e2
add corrosion to deps.txt
Taka152 Nov 8, 2024
8deba60
Merge branch 'main' into yingxiong/constrained_decoding
Taka152 Nov 8, 2024
b256e6d
fix merge
Taka152 Nov 8, 2024
e5d6dad
fix bugs
Taka152 Nov 8, 2024
2fd52d2
update linux gpu workflow
Taka152 Nov 8, 2024
a11684b
update linux gpu workfow
Taka152 Nov 8, 2024
56663c0
update linux gpu workflow
Taka152 Nov 8, 2024
8997064
update workflow
Taka152 Nov 8, 2024
ddda727
update workflow
Taka152 Nov 8, 2024
cb55778
update workflows
Taka152 Nov 8, 2024
4cf5b5f
add shared lib of llguidance
Taka152 Nov 15, 2024
18c2f6c
add disable_guidance option
Taka152 Nov 15, 2024
65340f2
fix format
Taka152 Nov 15, 2024
eca06f5
fix win error
Taka152 Nov 15, 2024
b306428
fix segfault
Taka152 Nov 18, 2024
df34b1e
fix segfault and move test
Taka152 Nov 18, 2024
2a5efe1
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Nov 18, 2024
92251d4
minor fixes
Taka152 Nov 18, 2024
56ab9ee
fix bug when is_stop
Taka152 Nov 20, 2024
03a6bb7
fixes for reviews
Taka152 Nov 20, 2024
29fc868
fix
Taka152 Nov 20, 2024
3b046c3
fix win error
Taka152 Nov 20, 2024
e9c818e
add rust env to dockerfile
Taka152 Nov 20, 2024
e06fb0a
fix dockerfile env
Taka152 Nov 20, 2024
ef141f2
update workflows
Taka152 Nov 20, 2024
13056b6
Update Rust environment in Dockerfiles
Taka152 Nov 20, 2024
22c7c37
Update Rust environment permissions in Dockerfiles
Taka152 Nov 20, 2024
a1186a5
Update Rust installation in Dockerfiles
Taka152 Nov 20, 2024
2c9b02c
revert linux arm workflow
Taka152 Nov 20, 2024
899edf9
Update Rust installation with specific version
Taka152 Nov 22, 2024
2d47c20
fix android error
Taka152 Nov 22, 2024
9a15385
fix for review
Taka152 Nov 25, 2024
ec09868
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Nov 27, 2024
ff94fe1
fix SetGuidance unit test
Taka152 Dec 6, 2024
88f8ef4
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Dec 6, 2024
4ca4075
fix format
Taka152 Dec 6, 2024
9849f65
fix to new continuous decoding api
Taka152 Dec 11, 2024
13e5100
remove comments
Taka152 Dec 12, 2024
aea5323
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Dec 12, 2024
9b5a6ce
fix
Taka152 Dec 12, 2024
a9390e3
fix segfault
Taka152 Dec 12, 2024
fc4b7e9
fix win build
Taka152 Dec 13, 2024
a0710d5
fix win error
Taka152 Dec 13, 2024
7d4d6bb
fix win error
Taka152 Dec 16, 2024
5d175a7
add comments
Taka152 Dec 16, 2024
52ccc8b
fix format
Taka152 Dec 16, 2024
161fcfc
fix bug
Taka152 Dec 17, 2024
ca6d86c
Merge branch 'main' into yingxiong/constrained_decoding
Taka152 Dec 17, 2024
8e03735
suuport build in ios GHA
Taka152 Dec 17, 2024
db3062b
update win azure ci
Taka152 Dec 17, 2024
0cdb4ac
update linux ci
Taka152 Dec 17, 2024
991a012
fix win ci
Taka152 Dec 17, 2024
4fdbea6
fix win ci
Taka152 Dec 17, 2024
e99ae65
fix macos arm
Taka152 Dec 17, 2024
ed96504
fix macos azure ci
Taka152 Dec 17, 2024
4cb9e55
fix for review
Taka152 Dec 18, 2024
e75166a
fix
Taka152 Dec 18, 2024
82d33e5
fix ios ci
Taka152 Dec 18, 2024
3498e0e
disable on ios ci
Taka152 Dec 18, 2024
2dbc6ee
Merge remote-tracking branch 'origin/main' into yingxiong/constrained…
Taka152 Dec 19, 2024
ce846c9
disable by default
Taka152 Dec 19, 2024
c644b26
remove azure ci code
Taka152 Dec 19, 2024
422af28
build and test with use_guidance
Taka152 Dec 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions .github/workflows/android-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,23 @@ jobs:
with:
submodules: true

- name: Install Rust Toolchain
uses: dtolnay/[email protected]

- name: Install Rust Android Toolchain
run: |
rustup target add --toolchain 1.82.0-x86_64-unknown-linux-gnu x86_64-linux-android

- name: Create Android build
run: |
set -e -x
rm -rf build
./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --update
./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --update --use_guidance

- name: Run Android build
run: |
set -e -x
./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --build
./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --build --use_guidance

- name: Enable KVM group perms so Android emulator can run
run: |
Expand Down
7 changes: 5 additions & 2 deletions .github/workflows/linux-cpu-x64-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ jobs:
with:
gradle-version: '8.6'

- name: Install Rust Toolchain
uses: dtolnay/[email protected]

- name: Get the Latest OnnxRuntime Nightly Version
shell: pwsh
run: |
Expand Down Expand Up @@ -74,8 +77,8 @@ jobs:
run: |
set -e -x
rm -rf build
cmake --preset linux_gcc_cpu_release
cmake --build --preset linux_gcc_cpu_release
cmake --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON
cmake --build --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON

- name: Install the python wheel and test dependencies
run: |
Expand Down
7 changes: 4 additions & 3 deletions .github/workflows/linux-cpu-x64-nightly-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ jobs:
- name: Checkout OnnxRuntime GenAI repo
uses: actions/checkout@v2


- name: Install Rust Toolchain
uses: dtolnay/[email protected]

- name: Download OnnxRuntime
run: |
Expand All @@ -45,8 +46,8 @@ jobs:
run: |
set -e -x
rm -rf build
cmake --preset linux_gcc_cpu_release
cmake --build --preset linux_gcc_cpu_release
cmake --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON
cmake --build --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON

- name: Install the python wheel and test dependencies
run: |
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/mac-cpu-arm64-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,16 @@ jobs:
mv ${{ env.ORT_PACKAGE_NAME }}/build/native/include ort/
mv ${{ env.ORT_PACKAGE_NAME }}/runtimes/osx-arm64/native/* ort/lib/

- name: Install Rust Toolchain
uses: dtolnay/[email protected]

- name: Configure CMake
run: |
cmake --preset macos_arm64_cpu_release

- name: Build with CMake
run: |
cmake --build --preset macos_arm64_cpu_release --parallel
cmake --build --preset macos_arm64_cpu_release --parallel -DUSE_GUIDANCE=ON
continue-on-error: false

- name: Install the python wheel and test dependencies
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/win-cpu-arm64-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,22 @@ jobs:
move ${{ env.ORT_PACKAGE_NAME }}/build/native/include ort/
move ${{ env.ORT_PACKAGE_NAME }}/runtimes/win-arm64/native/* ort/lib/

- name: Install Rust Toolchain
run: |
$exePath = "$env:TEMP\rustup-init.exe"
Taka152 marked this conversation as resolved.
Show resolved Hide resolved
(New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath)
& $exePath -y --default-toolchain=1.82.0
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"

- name: Configure CMake
run: |
python -m pip install wheel requests

cmake --preset windows_arm64_cpu_release
cmake --preset windows_arm64_cpu_release -DUSE_GUIDANCE=ON

- name: Build with CMake
run: |
cmake --build --preset windows_arm64_cpu_release --parallel
cmake --build --preset windows_arm64_cpu_release --parallel -DUSE_GUIDANCE=ON

- name: Install the Python Wheel and Test Dependencies
run: |
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/win-cpu-x64-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ jobs:
with:
gradle-version: '8.6'

- name: Install Rust Toolchain
run: |
$exePath = "$env:TEMP\rustup-init.exe"
(New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath)
& $exePath -y --default-toolchain=1.82.0
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"

- name: Download OnnxRuntime Nightly
shell: pwsh
run: |
Expand All @@ -78,11 +85,11 @@ jobs:

- name: Configure CMake
run: |
cmake --preset windows_x64_cpu_release
cmake --preset windows_x64_cpu_release -DUSE_GUIDANCE=ON

- name: Build with CMake
run: |
cmake --build --preset windows_x64_cpu_release --parallel
cmake --build --preset windows_x64_cpu_release --parallel -DUSE_GUIDANCE=ON

- name: Install the python wheel and test dependencies
run: |
Expand Down
13 changes: 10 additions & 3 deletions .github/workflows/win-cuda-x64-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,22 @@ jobs:
run: |
mkdir ort/lib
move ${{ env.ORT_PACKAGE_NAME }}/buildTransitive/native/include ort/
move ${{ env.ORT_PACKAGE_NAME }}/runtimes/win-x64/native/* ort/lib/
move ${{ env.ORT_PACKAGE_NAME }}/runtimes/win-x64/native/* ort/lib/

- name: Install Rust Toolchain
run: |
$exePath = "$env:TEMP\rustup-init.exe"
(New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath)
Taka152 marked this conversation as resolved.
Show resolved Hide resolved
& $exePath -y --default-toolchain=1.82.0
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"

- name: Configure CMake
run: |
cmake --preset windows_x64_cuda_release -T cuda=${{ env.cuda_dir }}\\v${{ env.cuda_version }}
cmake --preset windows_x64_cuda_release -T cuda=${{ env.cuda_dir }}\\v${{ env.cuda_version }} -DUSE_GUIDANCE=ON

- name: Build with CMake
run: |
cmake --build --preset windows_x64_cuda_release --parallel
cmake --build --preset windows_x64_cuda_release --parallel -DUSE_GUIDANCE=ON

- name: Add CUDA to PATH
run: |
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/win-directml-x64-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,20 @@ jobs:
mv $env:d3d12_dir\build\native\bin\x64\D3D12Core.dll ort\lib
mv $env:dml_dir\include\DirectML.h ort\include

- name: Install Rust Toolchain
run: |
$exePath = "$env:TEMP\rustup-init.exe"
(New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath)
& $exePath -y --default-toolchain=1.82.0
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"

- name: Configure CMake
run: |
cmake --preset windows_x64_directml_release -DTEST_PHI2=False
cmake --preset windows_x64_directml_release -DTEST_PHI2=False -DUSE_GUIDANCE=ON

- name: Build with CMake
run: |
cmake --build --preset windows_x64_directml_release --parallel
cmake --build --preset windows_x64_directml_release --parallel -DUSE_GUIDANCE=ON

- name: Install the Python Wheel and Test Dependencies
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ examples/csharp/HelloPhi/models

!test/test_models/hf-internal-testing/
!test/test_models/hf-internal-testing/tiny-random-gpt2*/*.onnx
!test/test_models/grammars/

.ipynb_checkpoints/
/src/java/.gradle
Expand Down
20 changes: 20 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ include(cmake/check_webgpu.cmake)
include(cmake/cxx_standard.cmake)

add_compile_definitions(BUILDING_ORT_GENAI_C)

if(USE_GUIDANCE)
add_compile_definitions(USE_GUIDANCE=1)
else()
add_compile_definitions(USE_GUIDANCE=0)
endif()

if(MSVC)
# set updated value for __cplusplus macro instead of 199711L
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:__cplusplus>)
Expand Down Expand Up @@ -142,6 +149,19 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER)
endif()
endif()


if(USE_GUIDANCE)
target_include_directories(onnxruntime-genai PUBLIC ${llguidance_SOURCE_DIR}/parser/)
target_include_directories(onnxruntime-genai-static PUBLIC ${llguidance_SOURCE_DIR}/parser/)
target_link_libraries(onnxruntime-genai PRIVATE llguidance_parser)
target_link_libraries(onnxruntime-genai-static PUBLIC llguidance_parser)
if (WIN32)
# bcrypt is needed for the rust std lib
target_link_libraries(onnxruntime-genai PRIVATE bcrypt)
target_link_libraries(onnxruntime-genai-static PRIVATE bcrypt)
endif()
endif()

if(CMAKE_GENERATOR_TOOLSET MATCHES "Visual Studio")
target_link_options(onnxruntime-genai PRIVATE "/CETCOMPAT")
target_compile_options(onnxruntime-genai PRIVATE "/sdl")
Expand Down
5 changes: 5 additions & 0 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescript

parser.add_argument("--use_dml", action="store_true", help="Whether to use DML. Default is to not use DML.")

parser.add_argument("--use_guidance", action="store_true", help="Whether to add guidance support. Default is False.")

# The following options are mutually exclusive (cross compiling options such as android, ios, etc.)
platform_group = parser.add_mutually_exclusive_group()
platform_group.add_argument("--android", action="store_true", help="Build for Android")
Expand Down Expand Up @@ -477,6 +479,7 @@ def update(args: argparse.Namespace, env: dict[str, str]):
f"-DUSE_DML={'ON' if args.use_dml else 'OFF'}",
f"-DENABLE_JAVA={'ON' if args.build_java else 'OFF'}",
f"-DBUILD_WHEEL={build_wheel}",
f"-DUSE_GUIDANCE={'ON' if args.use_guidance else 'OFF'}",
]

if args.ort_home:
Expand Down Expand Up @@ -535,6 +538,8 @@ def _get_opencv_toolchain_file():
"-DENABLE_TESTS=OFF",
"-DENABLE_MODEL_BENCHMARK=OFF",
]
if args.use_guidance:
command += ["-DRust_CARGO_TARGET=aarch64-apple-ios-sim"]

if args.macos == "Catalyst":
if args.cmake_generator == "Xcode":
Expand Down
2 changes: 2 additions & 0 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583e
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;2c3e936cfc3401ba7ebb79d02b9e52a50439ffc3
llguidance;https://github.com/microsoft/llguidance.git;4dc358feef3cdf0542a5f95b5f4e92761887a25d
corrosion;https://github.com/corrosion-rs/corrosion.git;64289b1d79d6d19cd2e241db515381a086bb8407
16 changes: 16 additions & 0 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,19 @@ list(APPEND EXTERNAL_LIBRARIES
ocos_operators
noexcep_operators
)

if(USE_GUIDANCE)
FetchContent_Declare(
Corrosion
GIT_REPOSITORY ${DEP_URL_corrosion}
GIT_TAG ${DEP_SHA1_corrosion}
)
onnxruntime_fetchcontent_makeavailable(Corrosion)
FetchContent_Declare(
llguidance
GIT_REPOSITORY ${DEP_URL_llguidance}
GIT_TAG ${DEP_SHA1_llguidance}
)
onnxruntime_fetchcontent_makeavailable(llguidance)
corrosion_import_crate(MANIFEST_PATH ${llguidance_SOURCE_DIR}/parser/Cargo.toml)
endif()
1 change: 1 addition & 0 deletions cmake/options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ option(USE_CUDA "Build with CUDA support" ON)
option(USE_ROCM "Build with ROCm support" ON)
option(USE_DML "Build with DML support" OFF)
option(USE_WEBGPU "Build with WEBGPU support" ON)
option(USE_GUIDANCE "Build with guidance support" ON)

# bindings
option(ENABLE_JAVA "Build the Java API." OFF)
Expand Down
4 changes: 4 additions & 0 deletions src/cuda/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ struct CudaInterfaceImpl : CudaInterface {
cuda::LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, stream);
}

void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream) override {
cuda::LaunchAddLogitsMask(batch_logits, batch_beam_size, vocab_size, logits_mask, stream);
}

void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) override {
cuda::UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, stream);
}
Expand Down
1 change: 1 addition & 0 deletions src/cuda/interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct CudaInterface : DeviceInterface {
virtual void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0;
virtual void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0;
virtual void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) = 0;
virtual void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream) = 0;
virtual void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) = 0;
virtual void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) = 0;
virtual void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) = 0;
Expand Down
16 changes: 16 additions & 0 deletions src/cuda/model_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,22 @@ void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_si
HandleEOSArray<<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count);
}

__global__ void AddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= batch_beam_size * vocab_size)
return;
int batch_index = index / vocab_size;
int vocab_index = index % vocab_size;
if (!(logits_mask[(batch_index * vocab_size + vocab_index) / 32] & (1 << (vocab_index % 32))))
batch_logits[index] = std::numeric_limits<float>::lowest();
}

void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream) {
int block_size = 256;
int num_blocks = (batch_beam_size * vocab_size + block_size - 1) / block_size;
AddLogitsMask<<<num_blocks, block_size, 0, stream>>>(batch_logits, batch_beam_size, vocab_size, logits_mask);
}

__global__ void ConvertFp16ToFp32(const half* src, float* dst, int count) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < count)
Expand Down
Loading
Loading