diff --git a/src/lib.rs b/src/lib.rs index 751118d..7c02ff7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -use fractions::Fraction; +use fractions::{Fraction, FractionError}; use std::fmt::{Debug, Display}; use std::ops::Add; use std::ops::Mul; @@ -15,6 +15,15 @@ pub enum MatrixError { InvalidSizeForSub, InvalidSizeForMul, ZeroSize, + FailedGauss, + FailedGaussJordan, + FractionError(FractionError), +} + +impl From for MatrixError { + fn from(err: FractionError) -> Self { + MatrixError::FractionError(err) + } } #[derive(PartialEq, Eq, Debug)] @@ -148,6 +157,37 @@ impl Matrix { None } + fn partial_pivoting(&mut self, col: usize, sign: &mut Fraction) -> Result { + if col >= self.columns { + return Err(MatrixError::ColumnOutOfRange); + } + + let mut max_row = col; + let mut max_value = self.get(col, col).unwrap().abs(); + + for r in (col + 1)..self.rows { + let val = self.get(r, col).unwrap().abs(); + if val > max_value { + max_value = val; + max_row = r; + } + } + + if max_value.is_zero() { + return Ok(false); + } + + if max_row != col { + match self.exchange_rows(col, max_row) { + Some(err) => return Err(err), + None => {} + }; + *sign = -*sign; + } + + Ok(true) + } + pub fn gaussian_elimination(&self) -> Result<(Matrix, Fraction), MatrixError> { let mut trig_matrix = Matrix { columns: self.columns, @@ -157,28 +197,13 @@ impl Matrix { let mut sign = Fraction::new(1, 1).unwrap(); for i in 0..self.columns { - let mut max_row = i; - let mut max_value = trig_matrix.get(i, i).unwrap().abs(); - - // We do parcial pivoting to avoid getting insane - // numbers that may result in overflow with fractions - for r in (i + 1)..self.rows { - let val = trig_matrix.get(r, i).unwrap().abs(); - if val > max_value { - max_value = val; - max_row = r; - } - } + // We do partial pivoting for better efifiency and security + let pivot_exists = trig_matrix.partial_pivoting(i, &mut sign)?; // If there ain't no other thing but 0 then we're // fucked, determinant is zero - if max_value.is_zero() { - return Ok((trig_matrix, Fraction::new(0, 1).unwrap())); - } - - if max_row != i { - trig_matrix.exchange_rows(i, max_row); - sign = -sign; + if !pivot_exists { + return Err(MatrixError::FailedGauss); } let pivot = *trig_matrix.get(i, i).unwrap(); @@ -203,13 +228,65 @@ impl Matrix { Ok((trig_matrix, sign)) } + pub fn gauss_jordan_elimination(&self) -> Result { + let mut new_matrix = Matrix { + columns: self.columns, + rows: self.rows, + data: self.data.clone(), + }; + + let mut dummy = Fraction::from(1); + + for i in 0..self.columns { + let pivot_exists = new_matrix.partial_pivoting(i, &mut dummy)?; + + if !pivot_exists { + return Err(MatrixError::FailedGaussJordan); + } + + let pivot = *new_matrix.get(i, i).unwrap(); + + let new_pivot_row = new_matrix + .get_row(i)? + .map(|x| *x / pivot) + .collect::, _>>()?; + + new_matrix.set_row(i, new_pivot_row); + + for r in 0..new_matrix.rows { + if r == 1 { + continue; + } + + let factor = *new_matrix.get(r, i).unwrap(); + + if factor.is_zero() { + continue; + } + + let new_row_normalized = new_matrix + .get_row(r)? + .zip(new_matrix.get_row(i)?) + .map(|(a, b)| *a - factor * *b) + .collect::>(); + + new_matrix.set_row(r, new_row_normalized); + } + } + + Ok(new_matrix) + } + pub fn get_determinant(&self) -> Result { if self.rows != self.columns { return Err(MatrixError::NotSquared); } - let (trig_matrix, sign) = self.gaussian_elimination()?; - + let (trig_matrix, sign) = match self.gaussian_elimination() { + Err(MatrixError::FailedGauss) => return Ok(Fraction::new(0, 1).unwrap()), + Ok((matrix, sign)) => (matrix, sign), + Err(err) => return Err(err), + }; // YES, now we got ourselves a triangular matrix, now we just // take the product of the diagonal and multiply by sign, that's // the determinant :)