-
Notifications
You must be signed in to change notification settings - Fork 4k
[ROCm] add ROCm support (pt. 2) #7039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jeffdaily
wants to merge
19
commits into
microsoft:master
Choose a base branch
from
jeffdaily:rocm3
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+162
−18
Open
Changes from 15 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
61ec4f1
[ROCm] re-add support for ROCm builds
jeffdaily e461e86
Fix error in cpp_tests/test_arrow.cpp.
jeffdaily 06741cc
update for ROCm 7 BC-breaking change to warpSize
jeffdaily e601565
lint
jeffdaily 1b3deb5
Revert "Fix error in cpp_tests/test_arrow.cpp."
jeffdaily 20996c9
partial revert of 61ec4f1aa215ca3381e7b79e98f002dc0c021d77
jeffdaily 3eafff7
add --use-rocm option to build-python.sh
jeffdaily 9655896
fix cuda build missing CUDASUCCESS_OR_FATAL in vector_cudahost.h
jeffdaily 0aa90c5
add rocm docs
jeffdaily 1b550a7
fix doc using pre-commit
jeffdaily 2e98916
apply reviewer suggestions
jeffdaily 76953f1
Merge branch 'master' into rocm3
jameslamb 5acc678
fix build-python.sh doc
jeffdaily 552fafc
Merge branch 'master' into rocm3
shiyu1994 6453a04
Merge branch 'master' into rocm3
StrikerRUS 5437ce4
fix build for rocm 7.0
jeffdaily 2f7bd8e
Merge branch 'master' into rocm3
StrikerRUS 548cec8
Merge branch 'master' into rocm3
jeffdaily 08ce574
Merge branch 'master' into rocm3
StrikerRUS File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -742,6 +742,65 @@ macOS | |
|
|
||
| The CUDA version is not supported on macOS. | ||
|
|
||
| Build ROCm Version | ||
| ~~~~~~~~~~~~~~~~~~ | ||
|
|
||
| The `original GPU version <#build-gpu-version>`__ of LightGBM (``device_type=gpu``) is based on OpenCL. | ||
|
|
||
| The ROCm-based version (``device_type=cuda``) is a separate implementation. Yes, the ROCm version reuses the ``device_type=cuda`` as a convenience for users. Use this version in Linux environments with an AMD GPU. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm ok with this, given that PyTorch does the same thing: https://docs.pytorch.org/docs/stable/notes/hip.html (no action required, just commenting for other maintainers to see) |
||
|
|
||
| Windows | ||
| ^^^^^^^ | ||
|
|
||
| The ROCm version is not supported on Windows. | ||
| Use the `GPU version <#build-gpu-version>`__ (``device_type=gpu``) for GPU acceleration on Windows. | ||
|
|
||
| Linux | ||
| ^^^^^ | ||
|
|
||
| On Linux, a ROCm version of LightGBM can be built using | ||
|
|
||
| - **CMake**, **gcc** and **ROCm**; | ||
| - **CMake**, **Clang** and **ROCm**. | ||
|
|
||
| Please refer to `the ROCm docs`_ for **ROCm** libraries installation. | ||
|
|
||
| After compilation the executable and ``.so`` files will be in ``LightGBM/`` folder. | ||
|
|
||
| gcc | ||
| *** | ||
|
|
||
| 1. Install `CMake`_, **gcc** and **ROCm**. | ||
|
|
||
| 2. Run the following commands: | ||
|
|
||
| .. code:: sh | ||
|
|
||
| git clone --recursive https://github.com/microsoft/LightGBM | ||
| cd LightGBM | ||
| cmake -B build -S . -DUSE_ROCM=ON | ||
| cmake --build build -j4 | ||
|
|
||
| Clang | ||
| ***** | ||
|
|
||
| 1. Install `CMake`_, **Clang**, **OpenMP** and **ROCm**. | ||
|
|
||
| 2. Run the following commands: | ||
|
|
||
| .. code:: sh | ||
|
|
||
| git clone --recursive https://github.com/microsoft/LightGBM | ||
| cd LightGBM | ||
| export CXX=clang++-14 CC=clang-14 # replace "14" with version of Clang installed on your machine | ||
| cmake -B build -S . -DUSE_ROCM=ON | ||
| cmake --build build -j4 | ||
|
|
||
| macOS | ||
| ^^^^^ | ||
|
|
||
| The ROCm version is not supported on macOS. | ||
|
|
||
| Build Java Wrapper | ||
| ~~~~~~~~~~~~~~~~~~ | ||
|
|
||
|
|
@@ -1051,6 +1110,8 @@ gcc | |
|
|
||
| .. _this detailed guide: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html | ||
|
|
||
| .. _the ROCm docs: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/ | ||
|
|
||
| .. _following docs: https://github.com/google/sanitizers/wiki | ||
|
|
||
| .. _Ninja: https://ninja-build.org | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,20 +1,68 @@ | ||
| /*! | ||
| * Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved. | ||
| */ | ||
| #ifndef LIGHTGBM_CUDA_CUDA_ROCM_INTEROP_H_ | ||
| #define LIGHTGBM_CUDA_CUDA_ROCM_INTEROP_H_ | ||
|
|
||
| #ifdef USE_CUDA | ||
|
|
||
| #if defined(__HIP_PLATFORM_AMD__) || defined(__HIP__) | ||
| #if defined(__HIP_PLATFORM_AMD__) | ||
|
|
||
| // ROCm doesn't have __shfl_down_sync, only __shfl_down without mask. | ||
| // Since mask is full 0xffffffff, we can use __shfl_down instead. | ||
| #define __shfl_down_sync(mask, val, offset) __shfl_down(val, offset) | ||
| #define __shfl_up_sync(mask, val, offset) __shfl_up(val, offset) | ||
| // ROCm warpSize is constexpr and is either 32 or 64 depending on gfx arch. | ||
| #define WARPSIZE warpSize | ||
|
|
||
| // ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd | ||
| #define atomicAdd_block atomicAdd | ||
| #else | ||
|
|
||
| // hipify | ||
| #include <hip/hip_runtime.h> | ||
| #define cudaDeviceProp hipDeviceProp_t | ||
| #define cudaDeviceSynchronize hipDeviceSynchronize | ||
| #define cudaError_t hipError_t | ||
| #define cudaFree hipFree | ||
| #define cudaFreeHost hipFreeHost | ||
| #define cudaGetDevice hipGetDevice | ||
| #define cudaGetDeviceProperties hipGetDeviceProperties | ||
| #define cudaGetErrorName hipGetErrorName | ||
| #define cudaGetErrorString hipGetErrorString | ||
| #define cudaGetLastError hipGetLastError | ||
| #define cudaHostAlloc hipHostAlloc | ||
| #define cudaHostAllocPortable hipHostAllocPortable | ||
| #define cudaMalloc hipMalloc | ||
| #define cudaMemcpy hipMemcpy | ||
| #define cudaMemcpyAsync hipMemcpyAsync | ||
| #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice | ||
| #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost | ||
| #define cudaMemcpyHostToDevice hipMemcpyHostToDevice | ||
| #define cudaMemoryTypeHost hipMemoryTypeHost | ||
| #define cudaMemset hipMemset | ||
| #define cudaPointerAttributes hipPointerAttribute_t | ||
| #define cudaPointerGetAttributes hipPointerGetAttributes | ||
| #define cudaSetDevice hipSetDevice | ||
| #define cudaStreamCreate hipStreamCreate | ||
| #define cudaStreamDestroy hipStreamDestroy | ||
| #define cudaStream_t hipStream_t | ||
| #define cudaSuccess hipSuccess | ||
|
|
||
| // warpSize is only allowed for device code. | ||
| // HIP header used to define warpSize as a constexpr that was either 32 or 64 | ||
| // depending on the target device, and then always set it to 64 for host code. | ||
| static inline constexpr int WARP_SIZE_INTERNAL() { | ||
| #if defined(__GFX9__) | ||
| return 64; | ||
| #else // __GFX9__ | ||
| return 32; | ||
| #endif // __GFX9__ | ||
| } | ||
| #define WARPSIZE (WARP_SIZE_INTERNAL()) | ||
|
|
||
| #else // __HIP_PLATFORM_AMD__ | ||
| // CUDA warpSize is not a constexpr, but always 32 | ||
| #define WARPSIZE 32 | ||
| #endif | ||
|
|
||
| #endif | ||
| #endif // USE_CUDA | ||
|
|
||
| #endif // LIGHTGBM_CUDA_CUDA_ROCM_INTEROP_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.