diff --git a/src/lib.rs b/src/lib.rs index 8012e57..0f0e76f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -920,4 +920,103 @@ mod tests { assert_eq!(col1, col2); } + + #[test] + fn test_get_diagonal_valid() { + let m = Matrix { + rows: 3, + columns: 3, + data: vec![ + Fraction::from(1), + Fraction::from(2), + Fraction::from(3), + Fraction::from(4), + Fraction::from(5), + Fraction::from(6), + Fraction::from(7), + Fraction::from(8), + Fraction::from(9), + ], + }; + + let diag: Vec = m.get_diagonal().unwrap().cloned().collect(); + + assert_eq!( + diag, + vec![Fraction::from(1), Fraction::from(5), Fraction::from(9)] + ); + } + + #[test] + fn test_get_diagonal_single_element() { + let m = Matrix::new(1, 1, Fraction::from(42)).unwrap(); + + let diag: Vec = m.get_diagonal().unwrap().cloned().collect(); + + assert_eq!(diag, vec![Fraction::from(42)]); + } + + #[test] + fn test_get_diagonal_identity_like() { + let mut data = vec![Fraction::from(0); 9]; + data[0] = Fraction::from(1); + data[4] = Fraction::from(1); + data[8] = Fraction::from(1); + + let m = Matrix { + rows: 3, + columns: 3, + data, + }; + + let diag: Vec = m.get_diagonal().unwrap().cloned().collect(); + + assert_eq!( + diag, + vec![Fraction::from(1), Fraction::from(1), Fraction::from(1)] + ); + } + + #[test] + fn test_get_diagonal_not_squared() { + let m = Matrix::new(2, 3, Fraction::from(0)).unwrap(); + + let result = m.get_diagonal(); + + assert!(matches!(result, Err(MatrixError::NotSquared))); + } + + #[test] + fn test_get_diagonal_matches_get() { + let m = Matrix { + rows: 4, + columns: 4, + data: (0..16).map(|x| Fraction::from(x)).collect(), + }; + + let diag = m.get_diagonal().unwrap(); + + for (i, val) in diag.enumerate() { + assert_eq!(*val, *m.get(i, i).unwrap()); + } + } + + #[test] + fn test_get_diagonal_length() { + let m = Matrix::new(5, 5, Fraction::from(3)).unwrap(); + + let diag: Vec = m.get_diagonal().unwrap().cloned().collect(); + + assert_eq!(diag.len(), 5); + } + + #[test] + fn test_get_diagonal_iterator_repeatability() { + let m = Matrix::new(3, 3, Fraction::from(7)).unwrap(); + + let d1: Vec = m.get_diagonal().unwrap().cloned().collect(); + let d2: Vec = m.get_diagonal().unwrap().cloned().collect(); + + assert_eq!(d1, d2); + } }