diff --git a/src/lib.rs b/src/lib.rs index 4996191..362a29e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -470,7 +470,7 @@ impl Add for Matrix { type Output = Result; fn add(self, other: Self) -> Self::Output { - if self.data.len() != other.data.len() { + if self.rows != other.rows || self.columns != other.columns { return Err(MatrixError::InvalidSizeForAdd); } @@ -491,7 +491,7 @@ impl Sub for Matrix { type Output = Result; fn sub(self, other: Self) -> Self::Output { - if self.data.len() != other.data.len() { + if self.rows != other.rows || self.columns != other.columns { return Err(MatrixError::InvalidSizeForSub); } @@ -3110,4 +3110,135 @@ mod tests { assert_eq!(*res_inv.get(i, 0).unwrap(), res_gj[i]); } } + + #[test] + fn test_add_basic() { + let a = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(1), + Fraction::from(2), + Fraction::from(3), + Fraction::from(4), + ], + }; + + let b = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(5), + Fraction::from(6), + Fraction::from(7), + Fraction::from(8), + ], + }; + + let result = (a + b).unwrap(); + + let expected = vec![ + Fraction::from(6), + Fraction::from(8), + Fraction::from(10), + Fraction::from(12), + ]; + + assert_eq!(result.data, expected); + } + + #[test] + fn test_add_negative_values() { + let a = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(-1), + Fraction::from(2), + Fraction::from(3), + Fraction::from(-4), + ], + }; + + let b = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(1), + Fraction::from(-2), + Fraction::from(-3), + Fraction::from(4), + ], + }; + + let result = (a + b).unwrap(); + + assert!(result.data.iter().all(|x| x.is_zero())); + } + + #[test] + fn test_add_commutative() { + let a = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(1), + Fraction::from(2), + Fraction::from(3), + Fraction::from(4), + ], + }; + + let b = Matrix { + rows: 2, + columns: 2, + data: vec![ + Fraction::from(5), + Fraction::from(6), + Fraction::from(7), + Fraction::from(8), + ], + }; + + let res1 = (a.clone() + b.clone()).unwrap(); + let res2 = (b + a).unwrap(); + + assert_eq!(res1.data, res2.data); + } + + #[test] + fn test_add_associative() { + let a = Matrix::new(2, 2, Fraction::from(1)).unwrap(); + let b = Matrix::new(2, 2, Fraction::from(2)).unwrap(); + let c = Matrix::new(2, 2, Fraction::from(3)).unwrap(); + + let res1 = ((a.clone() + b.clone()).unwrap() + c.clone()).unwrap(); + let res2 = (a + (b + c).unwrap()).unwrap(); + + assert_eq!(res1.data, res2.data); + } + + #[test] + fn test_add_fractions() { + let a = Matrix { + rows: 1, + columns: 2, + data: vec![Fraction::new(1, 2).unwrap(), Fraction::new(1, 3).unwrap()], + }; + + let b = Matrix { + rows: 1, + columns: 2, + data: vec![Fraction::new(1, 2).unwrap(), Fraction::new(2, 3).unwrap()], + }; + + let result: Matrix = (a + b).unwrap(); + + let expected = vec![ + Fraction::from(1), // 1/2 + 1/2 + Fraction::from(1), // 1/3 + 2/3 + ]; + + assert_eq!(result.data, expected); + } }