1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#![allow(non_upper_case_globals)]

use crate::ffi::cuda::*;

use std::ffi::{CStr};
use std::fmt;
use std::os::raw::{c_char, c_int, c_void};
use std::ptr::{null, null_mut};

pub fn get_version() -> CuResult<i32> {
  let mut version: c_int = -1;
  match unsafe { cuDriverGetVersion(&mut version as *mut c_int) } {
    cudaError_enum_CUDA_SUCCESS => Ok(version),
    e => Err(CuError(e)),
  }
}

/// Error wrapper around a raw `CUresult` status code returned by a CUDA
/// driver API call.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CuError(pub CUresult);

impl fmt::Debug for CuError {
  /// Renders as `CuError(<code>, name=<name>, desc=<desc>)`, where name and
  /// description come from `get_name` / `get_desc` (which may panic if the
  /// driver's lookup itself fails unexpectedly).
  fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
    let code = self.get_code();
    let name = self.get_name();
    let desc = self.get_desc();
    write!(f, "CuError({}, name={}, desc={})", code, name, desc)
  }
}

impl CuError {
  pub fn get_code(&self) -> u32 {
    self.0 as u32
  }

  pub fn get_name(&self) -> &'static str {
    let mut sp: *const c_char = null();
    match unsafe { cuGetErrorName(self.0, &mut sp as *mut *const c_char) } {
      cudaError_enum_CUDA_SUCCESS => {}
      cudaError_enum_CUDA_ERROR_INVALID_VALUE => {
        return "(invalid CUresult)";
      }
      e => panic!("cuGetErrorName failed: {}", e),
    }
    assert!(!sp.is_null());
    let s: &'static CStr = unsafe { CStr::from_ptr(sp) };
    match s.to_str() {
      Err(_) => "(invalid utf-8)",
      Ok(s) => s,
    }
  }

  pub fn get_desc(&self) -> &'static str {
    let mut sp: *const c_char = null();
    match unsafe { cuGetErrorString(self.0, &mut sp as *mut *const c_char) } {
      cudaError_enum_CUDA_SUCCESS => {}
      cudaError_enum_CUDA_ERROR_INVALID_VALUE => {
        return "(invalid CUresult)";
      }
      e => panic!("cuGetErrorString failed: {}", e),
    }
    assert!(!sp.is_null());
    let s: &'static CStr = unsafe { CStr::from_ptr(sp) };
    match s.to_str() {
      Err(_) => "(invalid utf-8)",
      Ok(s) => s,
    }
  }
}

/// Result of a CUDA driver call: `T` on success (defaulting to `()`), or the
/// failing status wrapped in `CuError`.
pub type CuResult<T = ()> = std::result::Result<T, CuError>;

pub fn cuda_initialized() -> bool {
  let mut count: i32 = 0;
  let result = unsafe { cuDeviceGetCount(&mut count as *mut c_int) };
  match result {
    cudaError_enum_CUDA_SUCCESS => true,
    cudaError_enum_CUDA_ERROR_NOT_INITIALIZED => false,
    e => panic!("cuDeviceGetCount failed: {:?}", CuError(e)),
  }
}

pub fn cuda_init() {
  let result = unsafe { cuInit(0) };
  match result {
    cudaError_enum_CUDA_SUCCESS => {}
    e => panic!("cuInit failed: {:?}", CuError(e)),
  }
}

#[derive(Debug)]
pub struct CuModule {
  ptr:  CUmodule,
}

impl CuModule {
  pub fn load<P: AsRef<CStr>>(module_path: &P) -> CuResult<CuModule> {
    let module_path_cstr = module_path.as_ref();
    let mut ptr: CUmodule = null_mut();
    match unsafe { cuModuleLoad(
        &mut ptr as *mut CUmodule,
        module_path_cstr.as_ptr()) }
    {
      cudaError_enum_CUDA_SUCCESS => Ok(CuModule{ptr}),
      e => Err(CuError(e)),
    }
  }

  pub fn load_fat_binary_image_unchecked(image: &[u8]) -> CuResult<CuModule> {
    let mut ptr: CUmodule = null_mut();
    match unsafe { cuModuleLoadFatBinary(
        &mut ptr as *mut CUmodule,
        image.as_ptr() as *const c_void) }
    {
      cudaError_enum_CUDA_SUCCESS => Ok(CuModule{ptr}),
      e => Err(CuError(e)),
    }
  }

  pub fn unload(mut self) -> CuResult {
    assert!(!self.ptr.is_null());
    match unsafe { cuModuleUnload(self.ptr) } {
      cudaError_enum_CUDA_SUCCESS => {
        self.ptr = null_mut();
        Ok(())
      }
      e => Err(CuError(e)),
    }
  }

  pub fn as_raw(&self) -> CUmodule {
    self.ptr
  }
}