test: solve with inverse

This commit is contained in:
2026-04-29 17:23:31 -06:00
parent dc4d0c8e73
commit 5ff1370d31

View File

@@ -444,17 +444,23 @@ impl Matrix {
pub fn solve_matrix_with_inverse(
&self,
solution_matrix: Matrix,
solution_matrix: Vec<Fraction>,
) -> Result<Matrix, MatrixError> {
if self.columns != self.rows {
return Err(MatrixError::NotSquared);
}
if solution_matrix.rows != self.rows || solution_matrix.columns != 1 {
if solution_matrix.len() != self.rows {
return Err(MatrixError::InvalidDataSize);
}
let solution_matrix = (self.inverse()? * solution_matrix)?;
let sol = Matrix {
rows: self.rows,
columns: 1,
data: solution_matrix,
};
let solution_matrix = (self.inverse()? * sol)?;
Ok(solution_matrix)
}
@@ -2962,4 +2968,146 @@ mod tests {
assert!(matches!(res, Err(MatrixError::NotSquared)));
}
#[test]
fn test_solve_inverse_simple() {
// 2x + y = 5
// x + y = 3
let m = Matrix {
rows: 2,
columns: 2,
data: vec![
Fraction::from(2),
Fraction::from(1),
Fraction::from(1),
Fraction::from(1),
],
};
let b = vec![Fraction::from(5), Fraction::from(3)];
let res = m.solve_matrix_with_inverse(b).unwrap();
let expected = vec![Fraction::from(2), Fraction::from(1)];
for i in 0..2 {
assert_eq!(*res.get(i, 0).unwrap(), expected[i]);
}
}
#[test]
fn test_solve_inverse_3x3() {
let m = Matrix {
rows: 3,
columns: 3,
data: vec![
Fraction::from(2),
Fraction::from(1),
Fraction::from(-1),
Fraction::from(-3),
Fraction::from(-1),
Fraction::from(2),
Fraction::from(-2),
Fraction::from(1),
Fraction::from(2),
],
};
let b = vec![Fraction::from(8), Fraction::from(-11), Fraction::from(-3)];
let res = m.solve_matrix_with_inverse(b).unwrap();
let expected = vec![Fraction::from(2), Fraction::from(3), Fraction::from(-1)];
for i in 0..3 {
assert_eq!(*res.get(i, 0).unwrap(), expected[i]);
}
}
#[test]
fn test_solve_inverse_identity() {
let m = Matrix {
rows: 2,
columns: 2,
data: vec![
Fraction::from(1),
Fraction::from(0),
Fraction::from(0),
Fraction::from(1),
],
};
let b = vec![Fraction::from(7), Fraction::from(9)];
let res = m.solve_matrix_with_inverse(b.clone()).unwrap();
for i in 0..2 {
assert_eq!(*res.get(i, 0).unwrap(), b[i]);
}
}
#[test]
fn test_solve_inverse_singular() {
let m = Matrix {
rows: 2,
columns: 2,
data: vec![
Fraction::from(1),
Fraction::from(2),
Fraction::from(2),
Fraction::from(4),
],
};
let b = vec![Fraction::from(3), Fraction::from(6)];
let res = m.solve_matrix_with_inverse(b);
assert!(matches!(res, Err(MatrixError::FailedGaussJordan)));
}
#[test]
fn test_solve_inverse_invalid_size() {
let m = Matrix::new(2, 2, Fraction::from(1)).unwrap();
let b = vec![Fraction::from(1)];
let res = m.solve_matrix_with_inverse(b);
assert!(matches!(res, Err(MatrixError::InvalidDataSize)));
}
#[test]
fn test_solve_inverse_not_squared() {
let m = Matrix::new(2, 3, Fraction::from(1)).unwrap();
let b = vec![Fraction::from(1), Fraction::from(2)];
let res = m.solve_matrix_with_inverse(b);
assert!(matches!(res, Err(MatrixError::NotSquared)));
}
#[test]
fn test_solve_inverse_vs_gauss_jordan() {
let m = Matrix {
rows: 2,
columns: 2,
data: vec![
Fraction::from(3),
Fraction::from(2),
Fraction::from(1),
Fraction::from(2),
],
};
let b = vec![Fraction::from(5), Fraction::from(5)];
let res_inv = m.solve_matrix_with_inverse(b.clone()).unwrap();
let res_gj = m.solve_matrix_with_gauss_jordan(b).unwrap();
for i in 0..2 {
assert_eq!(*res_inv.get(i, 0).unwrap(), res_gj[i]);
}
}
}