Function matmul_with_conj
pub fn matmul_with_conj<E>(
    acc: impl As2DMut<E>,
    acc_structure: BlockStructure,
    lhs: impl As2D<E>,
    lhs_structure: BlockStructure,
    conj_lhs: Conj,
    rhs: impl As2D<E>,
    rhs_structure: BlockStructure,
    conj_rhs: Conj,
    alpha: Option<E>,
    beta: E,
    parallelism: Parallelism<'_>,
) where
    E: ComplexField,
Computes the matrix product [alpha * acc] + beta * lhs * rhs (while optionally conjugating either or both of the input matrices) and stores the result in acc.
Performs the operation:
- acc = beta * Op_lhs(lhs) * Op_rhs(rhs) if alpha is None (in this case, the preexisting values in acc are not read, so it is allowed to be a view over uninitialized values if E: Copy),
- acc = alpha * acc + beta * Op_lhs(lhs) * Op_rhs(rhs) if alpha is Some(_), as in the accumulation sketch below.
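For instance, passing Some(1.0) for alpha keeps the existing contents of acc and adds the scaled product to them. A minimal sketch reusing the call shape from the §Example section (the operands and scaling factors here are arbitrary):

use faer::{
    linalg::matmul::triangular::{matmul_with_conj, BlockStructure},
    mat, Conj, Mat, Parallelism,
};

let lhs = mat![[1.0, 0.0], [0.0, 1.0]];
let rhs = mat![[2.0, 0.0], [0.0, 2.0]];
// Start from an accumulator that already holds values.
let mut acc = mat![[10.0, 0.0], [0.0, 10.0]];

// acc = 1.0 * acc + 3.0 * lhs * rhs
matmul_with_conj(
    acc.as_mut(),
    BlockStructure::Rectangular,
    lhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::No,
    rhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::No,
    Some(1.0),
    3.0,
    Parallelism::None,
);

// Diagonal entries: 10.0 (preexisting) + 3.0 * 2.0 (product) = 16.0.
assert!((acc.read(0, 0) - 16.0).abs() < 1e-10);
assert!((acc.read(1, 1) - 16.0).abs() < 1e-10);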
The left hand side and right hand side may be interpreted as triangular depending on the given corresponding matrix structure.
For the destination matrix, the result is:
- fully computed if the structure is rectangular,
- only the triangular half (including the diagonal) is computed if the structure is triangular,
- only the strict triangular half (excluding the diagonal) is computed if the structure is strictly triangular or unit triangular (see the sketch after this list).
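As an illustration of the destination structure (a minimal sketch, assuming the StrictTriangularLower variant of BlockStructure), only the strictly lower entries of acc are written; its diagonal and upper entries are left untouched:

use faer::{
    linalg::matmul::triangular::{matmul_with_conj, BlockStructure},
    mat, Conj, Mat, Parallelism,
};

let lhs = mat![[1.0, 2.0], [3.0, 4.0]];
let rhs = mat![[5.0, 6.0], [7.0, 8.0]];
let mut acc = Mat::<f64>::zeros(2, 2);

matmul_with_conj(
    acc.as_mut(),
    BlockStructure::StrictTriangularLower,
    lhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::No,
    rhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::No,
    None,
    1.0,
    Parallelism::None,
);

// Only the strictly lower entry (1, 0) is computed: 3.0 * 5.0 + 4.0 * 7.0 = 43.0.
assert!((acc.read(1, 0) - 43.0).abs() < 1e-10);
// The diagonal and upper entries keep the zeros they were initialized with.
assert!(acc.read(0, 0) == 0.0 && acc.read(0, 1) == 0.0 && acc.read(1, 1) == 0.0);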
Op_lhs is the identity if conj_lhs is Conj::No, and the conjugation operation if it is Conj::Yes.
Op_rhs is the identity if conj_rhs is Conj::No, and the conjugation operation if it is Conj::Yes.
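For complex element types, Conj::Yes conjugates the corresponding operand entry-wise before the product. A minimal sketch, assuming faer's complex_native::c64 type with a new(re, im) constructor and public re/im fields:

use faer::{
    complex_native::c64,
    linalg::matmul::triangular::{matmul_with_conj, BlockStructure},
    mat, Conj, Mat, Parallelism,
};

// conj(1 + 2i) * (3 + 4i) = (1 - 2i) * (3 + 4i) = 11 - 2i
let lhs = mat![[c64::new(1.0, 2.0)]];
let rhs = mat![[c64::new(3.0, 4.0)]];
let mut acc = Mat::<c64>::zeros(1, 1);

matmul_with_conj(
    acc.as_mut(),
    BlockStructure::Rectangular,
    lhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::Yes, // Op_lhs is conjugation: lhs entries are conjugated before multiplying.
    rhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::No, // Op_rhs is the identity.
    None,
    c64::new(1.0, 0.0),
    Parallelism::None,
);

assert!((acc.read(0, 0).re - 11.0).abs() < 1e-10);
assert!((acc.read(0, 0).im + 2.0).abs() < 1e-10);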
§Panics
Panics if the matrix dimensions are not compatible for matrix multiplication, i.e.:
- acc.nrows() == lhs.nrows()
- acc.ncols() == rhs.ncols()
- lhs.ncols() == rhs.nrows()
Additionally, matrices that are marked as triangular must be square, i.e., they must have the same number of rows and columns.
§Example
use faer::{
    linalg::matmul::triangular::{matmul_with_conj, BlockStructure},
    mat, unzipped, zipped, Conj, Mat, Parallelism,
};

let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
let rhs = mat![[4.0, 6.0], [5.0, 7.0]];

let mut acc = Mat::<f64>::zeros(2, 2);
let target = mat![
    [
        2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)),
        0.0,
    ],
    [
        2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)),
        2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)),
    ],
];

matmul_with_conj(
    acc.as_mut(),
    BlockStructure::TriangularLower,
    lhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::No,
    rhs.as_ref(),
    BlockStructure::Rectangular,
    Conj::No,
    None,
    2.5,
    Parallelism::None,
);

zipped!(acc.as_ref(), target.as_ref())
    .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10));