1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
#![allow(non_upper_case_globals)] use crate::ffi::cuda::*; use std::ffi::{CStr}; use std::fmt; use std::os::raw::{c_char, c_int, c_void}; use std::ptr::{null, null_mut}; pub fn get_version() -> CuResult<i32> { let mut version: c_int = -1; match unsafe { cuDriverGetVersion(&mut version as *mut c_int) } { cudaError_enum_CUDA_SUCCESS => Ok(version), e => Err(CuError(e)), } } #[derive(Clone, Copy, PartialEq, Eq)] pub struct CuError(pub CUresult); impl fmt::Debug for CuError { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { write!(f, "CuError({}, name={}, desc={})", self.get_code(), self.get_name(), self.get_desc()) } } impl CuError { pub fn get_code(&self) -> u32 { self.0 as u32 } pub fn get_name(&self) -> &'static str { let mut sp: *const c_char = null(); match unsafe { cuGetErrorName(self.0, &mut sp as *mut *const c_char) } { cudaError_enum_CUDA_SUCCESS => {} cudaError_enum_CUDA_ERROR_INVALID_VALUE => { return "(invalid CUresult)"; } e => panic!("cuGetErrorName failed: {}", e), } assert!(!sp.is_null()); let s: &'static CStr = unsafe { CStr::from_ptr(sp) }; match s.to_str() { Err(_) => "(invalid utf-8)", Ok(s) => s, } } pub fn get_desc(&self) -> &'static str { let mut sp: *const c_char = null(); match unsafe { cuGetErrorString(self.0, &mut sp as *mut *const c_char) } { cudaError_enum_CUDA_SUCCESS => {} cudaError_enum_CUDA_ERROR_INVALID_VALUE => { return "(invalid CUresult)"; } e => panic!("cuGetErrorString failed: {}", e), } assert!(!sp.is_null()); let s: &'static CStr = unsafe { CStr::from_ptr(sp) }; match s.to_str() { Err(_) => "(invalid utf-8)", Ok(s) => s, } } } pub type CuResult<T=()> = Result<T, CuError>; pub fn cuda_initialized() -> bool { let mut count: i32 = 0; let result = unsafe { cuDeviceGetCount(&mut count as *mut c_int) }; match result { cudaError_enum_CUDA_SUCCESS => true, cudaError_enum_CUDA_ERROR_NOT_INITIALIZED => false, e => panic!("cuDeviceGetCount failed: {:?}", CuError(e)), } } pub fn cuda_init() { let result = unsafe { cuInit(0) }; match result { cudaError_enum_CUDA_SUCCESS => {} e => panic!("cuInit failed: {:?}", CuError(e)), } } #[derive(Debug)] pub struct CuModule { ptr: CUmodule, } impl CuModule { pub fn load<P: AsRef<CStr>>(module_path: &P) -> CuResult<CuModule> { let module_path_cstr = module_path.as_ref(); let mut ptr: CUmodule = null_mut(); match unsafe { cuModuleLoad( &mut ptr as *mut CUmodule, module_path_cstr.as_ptr()) } { cudaError_enum_CUDA_SUCCESS => Ok(CuModule{ptr}), e => Err(CuError(e)), } } pub fn load_fat_binary_image_unchecked(image: &[u8]) -> CuResult<CuModule> { let mut ptr: CUmodule = null_mut(); match unsafe { cuModuleLoadFatBinary( &mut ptr as *mut CUmodule, image.as_ptr() as *const c_void) } { cudaError_enum_CUDA_SUCCESS => Ok(CuModule{ptr}), e => Err(CuError(e)), } } pub fn unload(mut self) -> CuResult { assert!(!self.ptr.is_null()); match unsafe { cuModuleUnload(self.ptr) } { cudaError_enum_CUDA_SUCCESS => { self.ptr = null_mut(); Ok(()) } e => Err(CuError(e)), } } pub fn as_raw(&self) -> CUmodule { self.ptr } }