Function matmul
pub fn matmul<E, LhsE, RhsE>(
acc: impl As2DMut<E>,
acc_structure: BlockStructure,
lhs: impl As2D<LhsE>,
lhs_structure: BlockStructure,
rhs: impl As2D<RhsE>,
rhs_structure: BlockStructure,
alpha: Option<E>,
beta: E,
parallelism: Parallelism<'_>,
)
Computes the matrix product [alpha * acc] + beta * lhs * rhs and stores the result in acc.
Performs the operation:
- acc = beta * lhs * rhs if alpha is None (in this case, the preexisting values in acc are not read, so it is allowed to be a view over uninitialized values if E: Copy),
- acc = alpha * acc + beta * lhs * rhs if alpha is Some(_).
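The two cases above can be illustrated with a naive reference loop. This is a minimal sketch, not faer's implementation: the function name naive_matmul is hypothetical, matrices are plain row-major Vec<Vec<f64>> values, and all operands are treated as rectangular. It shows that with alpha = None the old contents of acc are never read (here they start as NaN and do not affect the result), while Some(a) scales the old contents and accumulates into them.

```rust
/// Hypothetical dense reference for `acc = [alpha * acc] + beta * lhs * rhs`.
/// Not faer's code: row-major Vec<Vec<f64>>, every operand treated as rectangular.
fn naive_matmul(
    acc: &mut [Vec<f64>],
    lhs: &[Vec<f64>],
    rhs: &[Vec<f64>],
    alpha: Option<f64>,
    beta: f64,
) {
    for i in 0..acc.len() {
        for j in 0..acc[i].len() {
            // Inner product of row i of lhs with column j of rhs.
            let mut dot = 0.0;
            for p in 0..rhs.len() {
                dot += lhs[i][p] * rhs[p][j];
            }
            acc[i][j] = match alpha {
                // alpha == None: the previous value of acc[i][j] is never read.
                None => beta * dot,
                // alpha == Some(a): scale the previous value and accumulate.
                Some(a) => a * acc[i][j] + beta * dot,
            };
        }
    }
}

fn main() {
    let lhs = vec![vec![0.0, 2.0], vec![1.0, 3.0]];
    let rhs = vec![vec![4.0, 6.0], vec![5.0, 7.0]];

    // Overwrite mode: initial NaNs are ignored because alpha is None.
    let mut acc = vec![vec![f64::NAN; 2]; 2];
    naive_matmul(&mut acc, &lhs, &rhs, None, 2.5);
    assert_eq!(acc, vec![vec![25.0, 35.0], vec![47.5, 67.5]]);

    // Accumulate mode: acc = 1.0 * acc + 1.0 * lhs * rhs.
    let mut acc2 = vec![vec![1.0; 2]; 2];
    naive_matmul(&mut acc2, &lhs, &rhs, Some(1.0), 1.0);
    assert_eq!(acc2, vec![vec![11.0, 15.0], vec![20.0, 28.0]]);
}
```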
The left-hand side and right-hand side may be interpreted as triangular, depending on their corresponding structure arguments (lhs_structure and rhs_structure).
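As a sketch of how an input operand might be interpreted under each structure, the function below returns the value an entry is assumed to take. The Structure enum is a hypothetical stand-in for BlockStructure (only the lower variants are shown; the upper ones mirror them with i and j swapped), and the treatment of the diagonal is an assumption inferred from the variant names: strict means an implicit zero diagonal, unit means an implicit unit diagonal, and in both cases the stored diagonal values are not used.

```rust
/// Hypothetical stand-in for the structure kinds; NOT faer's BlockStructure type.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Structure {
    Rectangular,
    TriangularLower,
    StrictTriangularLower,
    UnitTriangularLower,
}

/// Value that entry (i, j) of an input matrix is assumed to take.
/// Assumption: entries outside the triangle count as zero, a strict structure
/// has an implicit zero diagonal, and a unit structure an implicit unit diagonal.
fn effective_entry(s: Structure, m: &[Vec<f64>], i: usize, j: usize) -> f64 {
    match s {
        Structure::Rectangular => m[i][j],
        Structure::TriangularLower => {
            if i >= j { m[i][j] } else { 0.0 }
        }
        Structure::StrictTriangularLower => {
            if i > j { m[i][j] } else { 0.0 }
        }
        Structure::UnitTriangularLower => {
            if i > j {
                m[i][j]
            } else if i == j {
                1.0
            } else {
                0.0
            }
        }
    }
}

fn main() {
    // The upper entry 9.0 is ignored by every lower structure.
    let m = vec![vec![5.0, 9.0], vec![3.0, 5.0]];
    assert_eq!(effective_entry(Structure::Rectangular, &m, 0, 1), 9.0);
    assert_eq!(effective_entry(Structure::TriangularLower, &m, 0, 1), 0.0);
    assert_eq!(effective_entry(Structure::TriangularLower, &m, 1, 1), 5.0);
    assert_eq!(effective_entry(Structure::StrictTriangularLower, &m, 1, 1), 0.0);
    assert_eq!(effective_entry(Structure::UnitTriangularLower, &m, 1, 1), 1.0);
    assert_eq!(effective_entry(Structure::UnitTriangularLower, &m, 1, 0), 3.0);
}
```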
For the destination matrix, the result is:
- fully computed if the structure is rectangular,
- only the triangular half (including the diagonal) is computed if the structure is triangular,
- only the strict triangular half (excluding the diagonal) is computed if the structure is strictly triangular or unit triangular.
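The destination rules above amount to a write mask. The sketch below encodes them for a square destination; the Structure enum is a hypothetical mirror of the structure kinds, not faer's type, and mask is an illustrative helper that renders written entries as '#' and untouched ones as '.', with rows separated by '/'.

```rust
/// Hypothetical mirror of the structure kinds; NOT faer's BlockStructure type.
#[derive(Clone, Copy)]
enum Structure {
    Rectangular,
    TriangularLower,
    StrictTriangularLower,
    UnitTriangularLower,
    TriangularUpper,
    StrictTriangularUpper,
    UnitTriangularUpper,
}

/// True if destination entry (i, j) is computed under the rules listed above.
fn is_computed(s: Structure, i: usize, j: usize) -> bool {
    use Structure::*;
    match s {
        Rectangular => true,
        // Triangular half including the diagonal.
        TriangularLower => i >= j,
        TriangularUpper => i <= j,
        // Strict and unit destinations: diagonal excluded.
        StrictTriangularLower | UnitTriangularLower => i > j,
        StrictTriangularUpper | UnitTriangularUpper => i < j,
    }
}

/// Illustrative helper: renders the n x n write mask, rows joined by '/'.
fn mask(s: Structure, n: usize) -> String {
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| if is_computed(s, i, j) { '#' } else { '.' })
                .collect::<String>()
        })
        .collect::<Vec<_>>()
        .join("/")
}

fn main() {
    assert_eq!(mask(Structure::Rectangular, 2), "##/##");
    assert_eq!(mask(Structure::TriangularLower, 3), "#../##./###");
    assert_eq!(mask(Structure::UnitTriangularLower, 3), ".../#../##.");
    assert_eq!(mask(Structure::StrictTriangularUpper, 3), ".##/..#/...");
    assert_eq!(mask(Structure::TriangularUpper, 2), "##/.#");
    assert_eq!(mask(Structure::StrictTriangularLower, 2), "../#.");
    assert_eq!(mask(Structure::UnitTriangularUpper, 2), ".#/..");
}
```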
§Panics
Panics if the matrix dimensions are not compatible for matrix multiplication, i.e., if any of the following does not hold:
- acc.nrows() == lhs.nrows()
- acc.ncols() == rhs.ncols()
- lhs.ncols() == rhs.nrows()
Additionally, matrices that are marked as triangular must be square, i.e., they must have the same number of rows and columns.
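The panic conditions above can be expressed as a single shape check. This hypothetical check_shapes function (not part of faer) takes each operand's (nrows, ncols) together with a flag saying whether its structure is triangular, and returns an error message instead of panicking.

```rust
/// Hypothetical shape check mirroring the documented panic conditions.
/// Each operand is given as (nrows, ncols) plus a "marked triangular" flag.
fn check_shapes(
    acc: (usize, usize),
    acc_tri: bool,
    lhs: (usize, usize),
    lhs_tri: bool,
    rhs: (usize, usize),
    rhs_tri: bool,
) -> Result<(), String> {
    if acc.0 != lhs.0 {
        return Err(format!("acc.nrows() = {} != lhs.nrows() = {}", acc.0, lhs.0));
    }
    if acc.1 != rhs.1 {
        return Err(format!("acc.ncols() = {} != rhs.ncols() = {}", acc.1, rhs.1));
    }
    if lhs.1 != rhs.0 {
        return Err(format!("lhs.ncols() = {} != rhs.nrows() = {}", lhs.1, rhs.0));
    }
    // Any operand marked triangular must be square.
    for (name, (r, c), tri) in [("acc", acc, acc_tri), ("lhs", lhs, lhs_tri), ("rhs", rhs, rhs_tri)] {
        if tri && r != c {
            return Err(format!("{name} is marked triangular but is {r}x{c}"));
        }
    }
    Ok(())
}

fn main() {
    // Compatible: (2x3) * (3x2) into a square triangular 2x2 destination.
    assert!(check_shapes((2, 2), true, (2, 3), false, (3, 2), false).is_ok());
    // Inner dimensions differ: lhs.ncols() = 3, rhs.nrows() = 2.
    assert!(check_shapes((2, 2), false, (2, 3), false, (2, 2), false).is_err());
    // Dimensions agree, but a triangular destination is not square.
    assert!(check_shapes((2, 3), true, (2, 4), false, (4, 3), false).is_err());
}
```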
§Example
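use faer::{
    linalg::matmul::triangular::{matmul, BlockStructure},
    mat, unzipped, zipped, Mat, Parallelism,
};

let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
// Destination; since alpha is None below, its initial values are never read.
let mut acc = Mat::<f64>::zeros(2, 2);

// Expected result: only the lower triangle (including the diagonal) is computed,
// so the entry above the diagonal keeps its initial value of zero.
let target = mat![
    [
        2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)),
        0.0,
    ],
    [
        2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)),
        2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)),
    ],
];

// acc = 2.5 * lhs * rhs, writing only the lower triangular half of acc.
matmul(
    acc.as_mut(),
    BlockStructure::TriangularLower,
    lhs.as_ref(),
    BlockStructure::Rectangular,
    rhs.as_ref(),
    BlockStructure::Rectangular,
    None, // alpha: overwrite acc instead of accumulating into it
    2.5,  // beta
    Parallelism::None,
);

// Compare element-wise against the expected result with a small tolerance.
zipped!(acc.as_ref(), target.as_ref())
    .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10));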
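Working the example through by hand: lhs * rhs = [[0*4 + 2*5, 0*6 + 2*7], [1*4 + 3*5, 1*6 + 3*7]] = [[10, 14], [19, 27]], so beta = 2.5 gives [[25, 35], [47.5, 67.5]]. Because acc is TriangularLower, the entry above the diagonal is not computed and keeps its initial zero, so target = [[25, 0], [47.5, 67.5]]. The std-only sketch below recomputes these values without faer; lower_target is a hypothetical helper written only for this illustration.

```rust
/// Hypothetical helper: recomputes the example's `target` with plain arrays.
/// Mirrors a TriangularLower destination: only entries with j <= i are written.
fn lower_target(beta: f64) -> [[f64; 2]; 2] {
    let lhs = [[0.0, 2.0], [1.0, 3.0]];
    let rhs = [[4.0, 6.0], [5.0, 7.0]];
    let mut acc = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..=i {
            acc[i][j] = beta * (lhs[i][0] * rhs[0][j] + lhs[i][1] * rhs[1][j]);
        }
    }
    acc
}

fn main() {
    // The strictly upper entry (0, 1) stays at its initial zero.
    assert_eq!(lower_target(2.5), [[25.0, 0.0], [47.5, 67.5]]);
}
```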