use fractions::{Fraction, FractionError}; use std::fmt::{Debug, Display}; use std::ops::Add; use std::ops::Mul; use std::ops::Sub; #[derive(Debug, Eq, PartialEq)] pub enum MatrixError { IndexOutOfRange, RowOutOfRange, ColumnOutOfRange, NotSquared, InvalidDataSize, InvalidSizeForAdd, InvalidSizeForSub, InvalidSizeForMul, ZeroSize, FailedGauss, FailedGaussJordan, FractionError(FractionError), } impl From for MatrixError { fn from(err: FractionError) -> Self { MatrixError::FractionError(err) } } #[derive(PartialEq, Eq, Debug)] pub struct Matrix { rows: usize, columns: usize, data: Vec, } impl Matrix { pub fn new(rows: usize, columns: usize, default: Fraction) -> Result { if columns < 1 || rows < 1 { return Err(MatrixError::ZeroSize); } Ok(Self { rows, columns, data: vec![default; rows * columns], }) } pub fn get(&self, row: usize, column: usize) -> Result<&Fraction, MatrixError> { if row >= self.rows || column >= self.columns { return Err(MatrixError::IndexOutOfRange); } let mut index = 0; index += row * self.columns; index += column; return Ok(&self.data[index]); } pub fn get_row(&self, row: usize) -> Result, MatrixError> { if row >= self.rows { return Err(MatrixError::RowOutOfRange); } let start = row * self.columns; let end = start + self.columns; return Ok(self.data[start..end].iter()); } pub fn get_column( &self, column: usize, ) -> Result, MatrixError> { if column >= self.columns { return Err(MatrixError::ColumnOutOfRange); } Ok((0..self.rows).map(move |i| &self.data[i * self.columns + column])) } pub fn get_diagonal(&self) -> Result, MatrixError> { if self.columns != self.rows { return Err(MatrixError::NotSquared); } Ok((0..self.rows).map(move |i| &self.data[i * self.columns + i])) } pub fn set(&mut self, row: usize, column: usize, data: Fraction) -> Result<(), MatrixError> { if row >= self.rows || column >= self.columns { return Err(MatrixError::IndexOutOfRange); } let mut index = 0; index += row * self.columns; index += column; self.data[index] = data; Ok(()) } pub fn set_row(&mut self, row: usize, data: Vec) -> Result<(), MatrixError> { if row >= self.rows { return Err(MatrixError::IndexOutOfRange); } if data.len() != self.columns { return Err(MatrixError::InvalidDataSize); } for i in 0..data.len() { self.set(row, i, data[i])?; } Ok(()) } pub fn set_column(&mut self, column: usize, data: Vec) -> Result<(), MatrixError> { if column >= self.columns { return Err(MatrixError::ColumnOutOfRange); } if data.len() != self.rows { return Err(MatrixError::InvalidDataSize); } for i in 0..data.len() { self.set(i, column, data[i])?; } Ok(()) } pub fn exchange_rows(&mut self, row1: usize, row2: usize) -> Result<(), MatrixError> { if row1 >= self.rows || row2 >= self.rows { return Err(MatrixError::RowOutOfRange); } let start1 = row1 * self.columns; let start2 = row2 * self.columns; for i in 0..self.columns { self.data.swap(start1 + i, start2 + i); } Ok(()) } pub fn exchange_columns(&mut self, column1: usize, column2: usize) -> Result<(), MatrixError> { if column1 >= self.columns || column2 >= self.columns { return Err(MatrixError::ColumnOutOfRange); } for i in 0..self.rows { let idx1 = column1 + i * self.columns; let idx2 = column2 + i * self.columns; self.data.swap(idx1, idx2); } Ok(()) } pub fn insert_column(&mut self, index: usize, data: Vec) -> Result<(), MatrixError> { if index >= self.columns { return Err(MatrixError::ColumnOutOfRange); } if data.len() != self.rows { return Err(MatrixError::InvalidDataSize); } for i in 0..self.rows { self.data.insert((i * self.columns) + index + i, data[i]); } self.columns += 1; Ok(()) } pub fn insert_rows(&mut self, index: usize, data: Vec) -> Result<(), MatrixError> { if index >= self.rows { return Err(MatrixError::RowOutOfRange); } if data.len() != self.columns { return Err(MatrixError::InvalidDataSize); } for i in 0..self.columns { self.data.insert((i * self.columns) + i, data[i]); } self.rows += 1; Ok(()) } fn partial_pivoting(&mut self, col: usize, sign: &mut Fraction) -> Result { if col >= self.columns { return Err(MatrixError::ColumnOutOfRange); } let mut max_row = col; let mut max_value = self.get(col, col).unwrap().abs(); for r in (col + 1)..self.rows { let val = self.get(r, col).unwrap().abs(); if val > max_value { max_value = val; max_row = r; } } if max_value.is_zero() { return Ok(false); } if max_row != col { self.exchange_rows(col, max_row)?; *sign = -*sign; } Ok(true) } pub fn gaussian_elimination(&self) -> Result<(Matrix, Fraction), MatrixError> { let mut trig_matrix = Matrix { columns: self.columns, rows: self.rows, data: self.data.clone(), }; let mut sign = Fraction::new(1, 1).unwrap(); for i in 0..self.columns { // We do partial pivoting for better efifiency and security let pivot_exists = trig_matrix.partial_pivoting(i, &mut sign)?; // If there ain't no other thing but 0 then we're // fucked, determinant is zero if !pivot_exists { return Err(MatrixError::FailedGauss); } let pivot = *trig_matrix.get(i, i).unwrap(); // The main gaussian elimination, not even I remember how // i did it in such a asimple way for x in (i + 1)..trig_matrix.rows { let m = (*trig_matrix.get(x, i).unwrap() / pivot).unwrap(); let row_x = trig_matrix.get_row(x)?; let row_i = trig_matrix.get_row(i)?; let new_row = row_x .zip(row_i) .map(|(a, b)| *a - m * *b) .collect::>(); trig_matrix.set_row(x, new_row)?; } } Ok((trig_matrix, sign)) } pub fn gauss_jordan_elimination(&self) -> Result { let mut new_matrix = Matrix { columns: self.columns, rows: self.rows, data: self.data.clone(), }; let mut dummy = Fraction::from(1); for i in 0..self.columns { let pivot_exists = new_matrix.partial_pivoting(i, &mut dummy)?; if !pivot_exists { return Err(MatrixError::FailedGaussJordan); } let pivot = *new_matrix.get(i, i).unwrap(); let new_pivot_row = new_matrix .get_row(i)? .map(|x| *x / pivot) .collect::, _>>()?; new_matrix.set_row(i, new_pivot_row)?; for r in 0..new_matrix.rows { if r == i { continue; } let factor = *new_matrix.get(r, i).unwrap(); if factor.is_zero() { continue; } let new_row_normalized = new_matrix .get_row(r)? .zip(new_matrix.get_row(i)?) .map(|(a, b)| *a - factor * *b) .collect::>(); new_matrix.set_row(r, new_row_normalized)?; } } Ok(new_matrix) } pub fn get_determinant(&self) -> Result { if self.rows != self.columns { return Err(MatrixError::NotSquared); } let (trig_matrix, sign) = match self.gaussian_elimination() { Err(MatrixError::FailedGauss) => return Ok(Fraction::new(0, 1).unwrap()), Ok((matrix, sign)) => (matrix, sign), Err(err) => return Err(err), }; // YES, now we got ourselves a triangular matrix, now we just // take the product of the diagonal and multiply by sign, that's // the determinant :) let determinant = sign * trig_matrix .get_diagonal()? .copied() .fold(Fraction::from(1i64), |acc, x| acc * x); return Ok(determinant); } } impl Add for Matrix { type Output = Result; fn add(self, other: Self) -> Self::Output { if self.data.len() != other.data.len() { return Err(MatrixError::InvalidSizeForAdd); } let mut new_data = Vec::new(); for i in 0..self.data.len() { new_data.push(self.data[i] + other.data[i]); } Ok(Matrix { columns: self.columns, rows: self.rows, data: new_data, }) } } impl Sub for Matrix { type Output = Result; fn sub(self, other: Self) -> Self::Output { if self.data.len() != other.data.len() { return Err(MatrixError::InvalidSizeForSub); } let mut new_data = Vec::new(); for i in 0..self.data.len() { new_data.push(self.data[i] - other.data[i]); } Ok(Matrix { columns: self.columns, rows: self.rows, data: new_data, }) } } impl Mul for Matrix { type Output = Result; fn mul(self, other: Self) -> Self::Output { if self.columns != other.rows { return Err(MatrixError::InvalidSizeForMul); } let mut new_data: Vec = Vec::new(); for i in 0..self.rows { for k in 0..other.columns { let current_column = other.get_column(k)?; let current_row = self.get_row(i)?; let mut new_value = Fraction::new(0, 1).unwrap(); for (a, b) in current_row.zip(current_column) { new_value = new_value + (*a * *b); } new_data.push(new_value); } } Ok(Matrix { rows: self.rows, columns: other.columns, data: new_data, }) } } impl Display for Matrix { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut display = String::new(); let mut index = 0; for _i in 0..self.columns { display += "{"; for _k in 0..self.rows { display += &format!(" {},", self.data[index]); index += 1; } display += " }\n"; } write!(f, "{}", display) } } #[cfg(test)] mod tests { use super::*; }