Tensor Core GEMM

This is a simple tensor core gemm of int8

#include "device_launch_parameters.h"
#include <cstdint>
#include <cuda_runtime.h>
#include <iostream>
#include <mma.h>
#include <numeric>
#include <random>
#include <vector>

using namespace nvcuda;

// m, n, k, k is the inner dimmension
// C is in shape m x n
// we use setting 16x16x16 to do mma
// we know a warp will result in a 16x16 in C.
// reformating the C into [(m + 15) / 16, (n + 15) / 16]
// so we can get c_row and c_col for which warp to deal with
__global__ void int8_gemm_tensor_core_kernel(const int8_t *A, const int8_t *B,
                                             int32_t *C, int M, int N, int K,
                                             int lda, int ldb, int ldc,
                                             int32_t alpha, int32_t beta) {
  int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) /
                32; // suppose we do not have y dim.
  // we need to know which row to deal with A.
  int how_many_col_in_tiles = (N + 15) / 16; // can be optimized
  int c_row =
      (warp_id / how_many_col_in_tiles) * 16; // warp num / reformated index
  int c_col = (warp_id % how_many_col_in_tiles) * 16;

  if (c_row >= M || c_col >= N)
    return;

  wmma::fragment<wmma::matrix_a, 16, 16, 16, int8_t, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, int8_t, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, int32_t> c_frag;
  wmma::fill_fragment(c_frag, 0);

  for (int k = 0; k < K; k += 16) {
    const int8_t *a_ptr = A + c_row * lda + k; // row by c_row, k by k
    const int8_t *b_ptr = B + c_col * ldb + k;

    wmma::load_matrix_sync(a_frag, a_ptr, lda);
    wmma::load_matrix_sync(b_frag, b_ptr, ldb);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }

  // c_ptr refer the the start of c matrix fragment
  int32_t *c_ptr = C + c_row * ldc + c_col;

  // load original C fragment
  wmma::fragment<wmma::accumulator, 16, 16, 16, int32_t> c_original_frag;
  wmma::load_matrix_sync(c_original_frag, c_ptr, ldc, wmma::mem_row_major);

  // incorporate alpha, beta
  for (int i = 0; i < c_frag.num_elements; i++) {
    c_frag.x[i] = alpha * c_frag.x[i] + beta * c_original_frag.x[i];
  }

  wmma::store_matrix_sync(c_ptr, c_frag, ldc, wmma::mem_row_major);
}

// -- CPU reference GEMM Function --
void cpu_gemm(const int8_t *A, const int8_t *B, int32_t *C, int M, int N, int K,
              int lda, int ldb, int ldc, int32_t alpha, int32_t beta) {
  std::vector<int32_t> C_temp(M * N);

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      int32_t sum = 0;
      for (int k = 0; k < K; k++) {
        sum += static_cast<int32_t>(A[i * lda + k]) *
               static_cast<int32_t>(B[j * ldb + k]);
      }
      C_temp[i * N + j] = sum;
    }
  }

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      int32_t original_c_val = C[i * ldc + j];
      C[i * ldc + j] = alpha * C_temp[i * N + j] + beta * original_c_val;
    }
  }
}

#define CHECK_CUDA(call)                                                       \
  do {                                                                         \
    cudaError_t err = call;                                                    \
    if (err != cudaSuccess) {                                                  \
      fprintf(stderr, "CUDA Error at %s:%d - %s\n", __FILE__, __LINE__,        \
              cudaGetErrorString(err));                                        \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)

// Main function
int main() {
  // --- Matrix Dimensions (adjust as needed, K should be multiple of 16 for
  // this kernel) ---
  const int M = 256; // Rows of A, Rows of C
  const int N = 512; // Cols of B, Cols of C
  const int K =
      128; // Cols of A, Rows of B (inner dimension, must be multiple of 16)

  // Leading dimensions (assume M, N, K are actual dimensions)
  const int lda = K; // A is M x K (row-major)
  const int ldb = K; // B is K x N (column-major)
  const int ldc = N; // C is M x N (row-major)

  // Alpha and Beta values
  const int32_t alpha = 1; // Example: simple matrix multiplication
  const int32_t beta = 0;  // Example: C is initialized to 0, or overwritten

  // --- Host Memory Allocation ---
  std::vector<int8_t> h_A(M * K);
  std::vector<int8_t> h_B(K * N);
  std::vector<int32_t> h_C(M * N);     // For GPU result
  std::vector<int32_t> h_C_ref(M * N); // For CPU reference

  // --- Data Initialization ---
  // Seed random number generator
  std::mt19937 gen(0); // Use a fixed seed for reproducibility
  std::uniform_int_distribution<> distrib(-128, 127); // int8_t range

  for (int i = 0; i < M * K; ++i) {
    h_A[i] = static_cast<int8_t>(distrib(gen));
  }
  for (int i = 0; i < K * N; ++i) {
    h_B[i] = static_cast<int8_t>(distrib(gen));
  }
  // Initialize C to some non-zero values for beta test, or zeros
  for (int i = 0; i < M * N; ++i) {
    h_C[i] = static_cast<int32_t>(distrib(gen));
    h_C_ref[i] = h_C[i]; // Copy initial C for reference
  }

  // --- Device Memory Allocation ---
  int8_t *d_A, *d_B;
  int32_t *d_C;

  CHECK_CUDA(cudaMalloc((void **)&d_A, M * K * sizeof(int8_t)));
  CHECK_CUDA(cudaMalloc((void **)&d_B, K * N * sizeof(int8_t)));
  CHECK_CUDA(cudaMalloc((void **)&d_C, M * N * sizeof(int32_t)));

  // --- Data Transfer (Host to Device) ---
  CHECK_CUDA(cudaMemcpy(d_A, h_A.data(), M * K * sizeof(int8_t),
                        cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d_B, h_B.data(), K * N * sizeof(int8_t),
                        cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d_C, h_C.data(), M * N * sizeof(int32_t),
                        cudaMemcpyHostToDevice)); // Copy initial C

  // --- Kernel Launch Configuration ---
  // A warp processes a 16x16 tile of C.
  // Total number of 16x16 tiles in C: (M/16) * (N/16) (assuming M, N are
  // multiples of 16) For general M, N: ceil(M/16) * ceil(N/16)
  int num_tiles_m = (M + 15) / 16;
  int num_tiles_n = (N + 15) / 16;
  int total_warps = num_tiles_m * num_tiles_n;

  // We'll use 8 warps per block (256 threads) for good occupancy.
  const int WARPS_PER_BLOCK = 8;
  dim3 blockDim(WARPS_PER_BLOCK * 32); // 256 threads per block

  // Calculate grid dimension based on total_warps
  dim3 gridDim((total_warps + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK);

  std::cout << "Launching kernel with:" << std::endl;
  std::cout << "  Grid Dim: (" << gridDim.x << ", " << gridDim.y << ", "
            << gridDim.z << ")" << std::endl;
  std::cout << "  Block Dim: (" << blockDim.x << ", " << blockDim.y << ", "
            << blockDim.z << ")" << std::endl;
  std::cout << "  Total Warps: " << total_warps << std::endl;

  // --- Kernel Execution ---
  int8_gemm_tensor_core_kernel<<<gridDim, blockDim>>>(
      d_A, d_B, d_C, M, N, K, lda, ldb, ldc, alpha, beta);

  CHECK_CUDA(cudaGetLastError());      // Check for errors during kernel launch
  CHECK_CUDA(cudaDeviceSynchronize()); // Wait for kernel to complete

  std::cout << "Kernel execution complete." << std::endl;

  // --- Data Transfer (Device to Host) ---
  CHECK_CUDA(cudaMemcpy(h_C.data(), d_C, M * N * sizeof(int32_t),
                        cudaMemcpyDeviceToHost));

  // --- CPU Reference Computation ---
  cpu_gemm(h_A.data(), h_B.data(), h_C_ref.data(), M, N, K, lda, ldb, ldc,
           alpha, beta);

  // --- Verification ---
  bool success = true;
  for (int i = 0; i < M * N; ++i) {
    if (h_C[i] != h_C_ref[i]) {
      std::cerr << "Mismatch at C[" << i / N << "][" << i % N
                << "]: GPU=" << h_C[i] << ", CPU=" << h_C_ref[i] << std::endl;
      success = false;
      // Print a few mismatches, then break
      if (++i % 10 == 0)
        break; // Limit error output
    }
  }

  if (success) {
    std::cout << "Verification PASSED! 🎉" << std::endl;
  } else {
    std::cerr << "Verification FAILED! 💔" << std::endl;
  }

  // --- Memory Deallocation ---
  CHECK_CUDA(cudaFree(d_A));
  CHECK_CUDA(cudaFree(d_B));
  CHECK_CUDA(cudaFree(d_C));

  std::cout << "Memory freed. Exiting." << std::endl;

  return 0;
}