From 5ff1370d318d713d0a22820eca6dafbaf5f5aec6 Mon Sep 17 00:00:00 2001 From: laentropia Date: Wed, 29 Apr 2026 17:23:31 -0600 Subject: [PATCH] test: solve with inverse --- src/lib.rs | 154 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 151 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index dd98891..4996191 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -444,17 +444,23 @@ impl Matrix { pub fn solve_matrix_with_inverse( &self, - solution_matrix: Matrix, + solution_matrix: Vec, ) -> Result { if self.columns != self.rows { return Err(MatrixError::NotSquared); } - if solution_matrix.rows != self.rows || solution_matrix.columns != 1 { + if solution_matrix.len() != self.rows { return Err(MatrixError::InvalidDataSize); } - let solution_matrix = (self.inverse()? * solution_matrix)?; + let sol = Matrix { + rows: self.rows, + columns: 1, + data: solution_matrix, + }; + + let solution_matrix = (self.inverse()? * sol)?; Ok(solution_matrix) } @@ -2962,4 +2968,146 @@ mod tests { assert!(matches!(res, Err(MatrixError::NotSquared))); } + + #[test] + fn test_solve_inverse_simple() { + // 2x + y = 5 + // x + y = 3 + let m = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(2), + Fraction::from(1), + Fraction::from(1), + Fraction::from(1), + ], + }; + + let b = vec![Fraction::from(5), Fraction::from(3)]; + + let res = m.solve_matrix_with_inverse(b).unwrap(); + + let expected = vec![Fraction::from(2), Fraction::from(1)]; + + for i in 0..2 { + assert_eq!(*res.get(i, 0).unwrap(), expected[i]); + } + } + + #[test] + fn test_solve_inverse_3x3() { + let m = Matrix { + rows: 3, + columns: 3, + data: vec![ + Fraction::from(2), + Fraction::from(1), + Fraction::from(-1), + Fraction::from(-3), + Fraction::from(-1), + Fraction::from(2), + Fraction::from(-2), + Fraction::from(1), + Fraction::from(2), + ], + }; + + let b = vec![Fraction::from(8), Fraction::from(-11), Fraction::from(-3)]; + + let res = m.solve_matrix_with_inverse(b).unwrap(); + + let expected = vec![Fraction::from(2), Fraction::from(3), Fraction::from(-1)]; + + for i in 0..3 { + assert_eq!(*res.get(i, 0).unwrap(), expected[i]); + } + } + + #[test] + fn test_solve_inverse_identity() { + let m = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(1), + Fraction::from(0), + Fraction::from(0), + Fraction::from(1), + ], + }; + + let b = vec![Fraction::from(7), Fraction::from(9)]; + + let res = m.solve_matrix_with_inverse(b.clone()).unwrap(); + + for i in 0..2 { + assert_eq!(*res.get(i, 0).unwrap(), b[i]); + } + } + + #[test] + fn test_solve_inverse_singular() { + let m = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(1), + Fraction::from(2), + Fraction::from(2), + Fraction::from(4), + ], + }; + + let b = vec![Fraction::from(3), Fraction::from(6)]; + + let res = m.solve_matrix_with_inverse(b); + + assert!(matches!(res, Err(MatrixError::FailedGaussJordan))); + } + + #[test] + fn test_solve_inverse_invalid_size() { + let m = Matrix::new(2, 2, Fraction::from(1)).unwrap(); + + let b = vec![Fraction::from(1)]; + + let res = m.solve_matrix_with_inverse(b); + + assert!(matches!(res, Err(MatrixError::InvalidDataSize))); + } + + #[test] + fn test_solve_inverse_not_squared() { + let m = Matrix::new(2, 3, Fraction::from(1)).unwrap(); + + let b = vec![Fraction::from(1), Fraction::from(2)]; + + let res = m.solve_matrix_with_inverse(b); + + assert!(matches!(res, Err(MatrixError::NotSquared))); + } + + #[test] + fn test_solve_inverse_vs_gauss_jordan() { + let m = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(3), + Fraction::from(2), + Fraction::from(1), + Fraction::from(2), + ], + }; + + let b = vec![Fraction::from(5), Fraction::from(5)]; + + let res_inv = m.solve_matrix_with_inverse(b.clone()).unwrap(); + let res_gj = m.solve_matrix_with_gauss_jordan(b).unwrap(); + + for i in 0..2 { + assert_eq!(*res_inv.get(i, 0).unwrap(), res_gj[i]); + } + } }