View Javadoc
1   package org.djunits.util;
2   
3   import org.djunits.vecmat.NonInvertibleMatrixException;
4   
5   /**
6    * MatrixMath implements a number of methods for linear algebra operations on square matrices, such as LU decomposition,
7    * inverse, trace, etc.
8    * <p>
9    * Copyright (c) 2025-2026 Delft University of Technology, Jaffalaan 5, 2628 BX Delft, the Netherlands. All rights reserved. See
10   * for project information <a href="https://djunits.org" target="_blank">https://djunits.org</a>. The DJUNITS project is
11   * distributed under a <a href="https://djunits.org/docs/license.html" target="_blank">three-clause BSD-style license</a>.
12   * @author Alexander Verbraeck
13   */
14  @SuppressWarnings({"checkstyle:needbraces", "checkstyle:localvariablename"})
15  public final class MatrixMath
16  {
17      /** The default tolerance for operations when no tolerance is given. */
18      protected static final double DEFAULT_TOL = 1e-12;
19  
20      /** */
21      private MatrixMath()
22      {
23      }
24  
25      // ---------- Helpers ----------
26  
27      /**
28       * Return the index in a row-major storage of a square matrix: [a11, a12, ..., a21, a22, ..., ann].
29       * @param n the order of the square matrix
30       * @param r the row to look up (0-based)
31       * @param c the column to look up (0-based)
32       * @return the index in the array for row r, column c
33       */
34      private static int idx(final int n, final int r, final int c)
35      {
36          return r * n + c;
37      }
38  
39      // ---------- Multiplication ----------
40  
41      /**
42       * Multiply A (m x n, row-major) with B (n x p, row-major) to produce C (m x p, row-major). Storage: row-major means A[i,k]
43       * is at aSi[i * n + k], B[k,j] at bSi[k * p + j], and C[i,j] at result[i * p + j].
44       * @param aSi matrix A, length must be m * n, stored as row-major double[]
45       * @param bSi matrix B, length must be n * p, stored as row-major double[]
46       * @param m rows of A (and C)
47       * @param n columns of A == rows of B
48       * @param p columns of B (and C)
49       * @return C = A * B, as row-major double[] (length m * p)
50       * @throws IllegalArgumentException if input lengths are inconsistent
51       */
52      @SuppressWarnings("checkstyle:needbraces")
53      public static double[] multiply(final double[] aSi, final double[] bSi, final int m, final int n, final int p)
54      {
55          if (aSi.length != m * n)
56          {
57              throw new IllegalArgumentException("A length " + aSi.length + " != m*n (" + (m * n) + ")");
58          }
59          if (bSi.length != n * p)
60          {
61              throw new IllegalArgumentException("B length " + bSi.length + " != n*p (" + (n * p) + ")");
62          }
63  
64          final double[] result = new double[m * p];
65  
66          // Loop order: i (row of C/A), k (shared dim), j (column of C/B)
67          // Rationale:
68          // - A[i,k] is contiguous in k -> hoist aik
69          // - B[k,j] is contiguous in j for fixed k -> inner loop over j is cache-friendly
70          // - C[i,*] row is contiguous -> row-wise accumulation
71          for (int i = 0; i < m; i++)
72          {
73              final int aiBase = i * n; // start of A's row i
74              final int ciBase = i * p; // start of C's row i
75              for (int k = 0; k < n; k++)
76              {
77                  final double aik = aSi[aiBase + k]; // A[i,k]
78                  final int bkBase = k * p; // start of B's row k
79                  for (int j = 0; j < p; j++)
80                  {
81                      result[ciBase + j] += aik * bSi[bkBase + j]; // C[i,j] += A[i,k] * B[k,j]
82                  }
83              }
84          }
85          return result;
86      }
87  
88      // ---------- Basic invariants ----------
89  
90      /**
91       * Calculate the trace of the matrix.
92       * @param aSi the row-major storage of the matrix
93       * @param n the order of the matrix
94       * @return the trace of the matrix
95       */
96      public static double trace(final double[] aSi, final int n)
97      {
98          double t = 0.0;
99          for (int i = 0; i < n; i++)
100             t += aSi[idx(n, i, i)];
101         return t;
102     }
103 
104     /**
105      * Return whether the matrix is symmetric, using a default tolerance.
106      * @param aSi the row-major storage of the matrix
107      * @param n the order of the matrix
108      * @return whether the matrix is symmetric
109      */
110     public static boolean isSymmetric(final double[] aSi, final int n)
111     {
112         return isSymmetric(aSi, n, DEFAULT_TOL);
113     }
114 
115     /**
116      * Return whether the matrix is symmetric, within the given tolerance.
117      * @param aSi the row-major storage of the matrix
118      * @param n the order of the matrix
119      * @param tol the tolerance in SI units
120      * @return whether the matrix is symmetric
121      */
122     public static boolean isSymmetric(final double[] aSi, final int n, final double tol)
123     {
124         for (int i = 0; i < n; i++)
125         {
126             for (int j = i + 1; j < n; j++)
127             {
128                 double aij = aSi[idx(n, i, j)];
129                 double aji = aSi[idx(n, j, i)];
130                 if (Math.abs(aij - aji) > tol)
131                     return false;
132             }
133         }
134         return true;
135     }
136 
137     /**
138      * Return whether the matrix is skew-symmetric, using a default tolerance.
139      * @param aSi the row-major storage of the matrix
140      * @param n the order of the matrix
141      * @return whether the matrix is symmetric
142      */
143     public static boolean isSkewSymmetric(final double[] aSi, final int n)
144     {
145         return isSkewSymmetric(aSi, n, DEFAULT_TOL);
146     }
147 
148     /**
149      * Return whether the matrix is skew-symmetric, within the given tolerance.
150      * @param aSi the row-major storage of the matrix
151      * @param n the order of the matrix
152      * @param tol the tolerance in SI units
153      * @return whether the matrix is symmetric
154      */
155     public static boolean isSkewSymmetric(final double[] aSi, final int n, final double tol)
156     {
157         for (int i = 0; i < n; i++)
158         {
159             if (Math.abs(aSi[idx(n, i, i)]) > tol)
160                 return false; // diagonal must be ~0
161             for (int j = i + 1; j < n; j++)
162             {
163                 double aij = aSi[idx(n, i, j)];
164                 double aji = aSi[idx(n, j, i)];
165                 if (Math.abs(aij + aji) > tol)
166                     return false; // a_ij = -a_ji
167             }
168         }
169         return true;
170     }
171 
172     // ---------- LU decomposition with partial pivoting ----------
173 
174     /**
175      * Helper class for LU decomposition with partial pivoting.
176      */
177     protected static final class LU
178     {
179         /** combined L (unit diag) and U, row-major. */
180         private final double[] lu;
181 
182         /** row permutations. */
183         private final int[] piv;
184 
185         /** the pivot sign, +1 or -1. */
186         private final int pivotSign;
187 
188         /** scale for tolerance decisions. */
189         private final double scale;
190 
191         /**
192          * Construct an LU instance.
193          * @param lu combined L (unit diag) and U, row-major
194          * @param piv row permutations
195          * @param pivotSign the pivot sign, +1 or -1
196          * @param scale scale for tolerance decisions
197          */
198         LU(final double[] lu, final int[] piv, final int pivotSign, final double scale)
199         {
200             this.lu = lu;
201             this.piv = piv;
202             this.pivotSign = pivotSign;
203             this.scale = scale;
204         }
205     }
206 
207     /**
208      * Decompose.
209      * @param a the row-major storage of the matrix
210      * @param n the order of the square matrix
211      * @return an LU object containing L and U in one array
212      */
213     protected static LU luDecompose(final double[] a, final int n)
214     {
215         double[] lu = a.clone();
216         int[] piv = new int[n];
217         for (int i = 0; i < n; i++)
218             piv[i] = i;
219         int pivotSign = 1;
220         double scale = Math2.maxAbs(a);
221 
222         for (int k = 0; k < n; k++)
223         {
224             // Find pivot
225             int p = k;
226             double max = Math.abs(lu[idx(n, k, k)]);
227             for (int i = k + 1; i < n; i++)
228             {
229                 double v = Math.abs(lu[idx(n, i, k)]);
230                 if (v > max)
231                 {
232                     max = v;
233                     p = i;
234                 }
235             }
236             // Swap rows if needed
237             if (p != k)
238             {
239                 for (int j = 0; j < n; j++)
240                 {
241                     double tmp = lu[idx(n, k, j)];
242                     lu[idx(n, k, j)] = lu[idx(n, p, j)];
243                     lu[idx(n, p, j)] = tmp;
244                 }
245                 int tmpi = piv[k];
246                 piv[k] = piv[p];
247                 piv[p] = tmpi;
248                 pivotSign = -pivotSign;
249             }
250 
251             double pivot = lu[idx(n, k, k)];
252             if (pivot != 0.0)
253             {
254                 // Compute multipliers
255                 for (int i = k + 1; i < n; i++)
256                 {
257                     lu[idx(n, i, k)] /= pivot;
258                 }
259                 // Rank-1 update to the trailing submatrix
260                 for (int i = k + 1; i < n; i++)
261                 {
262                     double lik = lu[idx(n, i, k)];
263                     if (lik == 0.0)
264                         continue;
265                     for (int j = k + 1; j < n; j++)
266                     {
267                         lu[idx(n, i, j)] -= lik * lu[idx(n, k, j)];
268                     }
269                 }
270             }
271             // If pivot == 0, we still continue; this indicates singular/deficient rank.
272         }
273         return new LU(lu, piv, pivotSign, scale);
274     }
275 
276     /**
277      * Determine whether the matrix is singular, based on the LU decomposition.
278      * @param luRes The LU result
279      * @param n the order of the matrix
280      * @param relTol the relative tolerance
281      * @return whether the matrix is singular
282      */
283     protected static boolean isSingularFromLU(final LU luRes, final int n, final double relTol)
284     {
285         double tol = Math.max(1.0, luRes.scale) * relTol;
286         for (int i = 0; i < n; i++)
287         {
288             if (Math.abs(luRes.lu[idx(n, i, i)]) <= tol)
289                 return true;
290         }
291         return false;
292     }
293 
294     /**
295      * Return the determinant, based on the LU decomposition.
296      * @param luRes The LU result
297      * @param n the order of the matrix
298      * @return the determinant of the matrix
299      */
300     protected static double detFromLU(final LU luRes, final int n)
301     {
302         double det = luRes.pivotSign;
303         for (int i = 0; i < n; i++)
304         {
305             det *= luRes.lu[idx(n, i, i)];
306         }
307         return det;
308     }
309 
310     /**
311      * Solve LU x = b for one right-hand side vector b (vector solve).
312      * @param luRes The LU result
313      * @param n the order of the matrix
314      * @param b the right-hand side
315      */
316     protected static void luSolveInPlace(final LU luRes, final int n, final double[] b)
317     {
318         // Apply row permutations to b
319         double[] bp = b.clone();
320         for (int i = 0; i < n; i++)
321         {
322             b[i] = bp[luRes.piv[i]];
323         }
324         // Forward substitution: solve L y = Pb
325         for (int i = 0; i < n; i++)
326         {
327             double sum = b[i];
328             for (int j = 0; j < i; j++)
329             {
330                 sum -= luRes.lu[idx(n, i, j)] * b[j];
331             }
332             b[i] = sum; // L has unit diagonal
333         }
334         // Back substitution: solve U x = y
335         for (int i = n - 1; i >= 0; i--)
336         {
337             double sum = b[i];
338             for (int j = i + 1; j < n; j++)
339             {
340                 sum -= luRes.lu[idx(n, i, j)] * b[j];
341             }
342             b[i] = sum / luRes.lu[idx(n, i, i)];
343         }
344     }
345 
346     // ---------- Determinant ----------
347 
348     /**
349      * Calculate the determinant, based on the role of Sarrus. See
350      * <a href="https://en.wikipedia.org/wiki/Rule_of_Sarrus">https://en.wikipedia.org/wiki/Rule_of_Sarrus</a>.
351      * @param aSi the row-major storage of the matrix
352      * @param n the order of the matrix
353      * @return the determinant
354      */
355     public static double determinant(final double[] aSi, final int n)
356     {
357         if (n == 1)
358             return aSi[0];
359         if (n == 2)
360         {
361             return aSi[0] * aSi[3] - aSi[1] * aSi[2];
362         }
363         if (n == 3)
364         {
365             // Sarrus
366             double a = aSi[0], b = aSi[1], c = aSi[2];
367             double d = aSi[3], e = aSi[4], f = aSi[5];
368             double g = aSi[6], h = aSi[7], i = aSi[8];
369             return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
370         }
371         LU luRes = luDecompose(aSi, n);
372         return detFromLU(luRes, n);
373     }
374 
375     // ---------- Inverse ----------
376 
377     /**
378      * Calculate the inverse. Fast methods for n=1, 2, 3. For higher order matrices, the calculation is based on the LU
379      * decomposition.
380      * @param aSi the row-major storage of the matrix
381      * @param n the order of the matrix
382      * @return the inverse of the matrix
383      * @throws NonInvertibleMatrixException when the matrix cannot be inverted
384      */
385     public static double[] inverse(final double[] aSi, final int n) throws NonInvertibleMatrixException
386     {
387         if (n == 1)
388         {
389             double v = aSi[0];
390             if (v == 0.0)
391                 throw new NonInvertibleMatrixException("Singular 1x1 matrix");
392             return new double[] {1.0 / v};
393         }
394         if (n == 2)
395         {
396             double a = aSi[0], b = aSi[1], c = aSi[2], d = aSi[3];
397             double det = a * d - b * c;
398             if (Math.abs(det) <= DEFAULT_TOL * Math.max(1.0, Math2.maxAbs(aSi)))
399             {
400                 throw new NonInvertibleMatrixException("Singular 2x2 matrix");
401             }
402             double invDet = 1.0 / det;
403             double[] inv = new double[] {d * invDet, -b * invDet, -c * invDet, a * invDet};
404             return inv;
405         }
406         if (n == 3)
407         {
408             // Use adj(A)^T / det(A)
409             double a = aSi[0], b = aSi[1], c = aSi[2];
410             double d = aSi[3], e = aSi[4], f = aSi[5];
411             double g = aSi[6], h = aSi[7], i = aSi[8];
412 
413             double A = (e * i - f * h);
414             double B = -(d * i - f * g);
415             double C = (d * h - e * g);
416             double D = -(b * i - c * h);
417             double E = (a * i - c * g);
418             double F = -(a * h - b * g);
419             double G = (b * f - c * e);
420             double H = -(a * f - c * d);
421             double I = (a * e - b * d);
422 
423             double det = a * A + b * B + c * C;
424             if (Math.abs(det) <= DEFAULT_TOL * Math.max(1.0, Math2.maxAbs(aSi)))
425             {
426                 throw new NonInvertibleMatrixException("Singular 3x3 matrix");
427             }
428             double invDet = 1.0 / det;
429             // inverse = adj(A)^T / det = cof(A) / det (since we computed cofactors already in place)
430             double[] inv = new double[] {A * invDet, D * invDet, G * invDet, B * invDet, E * invDet, H * invDet, C * invDet,
431                     F * invDet, I * invDet};
432             return inv;
433         }
434 
435         // General n: LU + solve for identity
436         LU luRes = luDecompose(aSi, n);
437         if (isSingularFromLU(luRes, n, DEFAULT_TOL))
438         {
439             throw new NonInvertibleMatrixException("Matrix is singular to working precision");
440         }
441         double[] inv = new double[n * n];
442         double[] e = new double[n]; // RHS basis vector
443         for (int col = 0; col < n; col++)
444         {
445             // e = unit vector
446             java.util.Arrays.fill(e, 0.0);
447             e[col] = 1.0;
448             double[] x = e.clone();
449             luSolveInPlace(luRes, n, x);
450             // write column to inv (row-major target)
451             for (int row = 0; row < n; row++)
452             {
453                 inv[idx(n, row, col)] = x[row];
454             }
455         }
456         return inv;
457     }
458 
459     // ---------- Adjugate (cofactor transpose) ----------
460 
461     /**
462      * Calculate the adjugate. Fast methods for n=1, 2, 3.
463      * @param aSi the row-major storage of the matrix
464      * @param n the order of the matrix
465      * @return the adjugate of the matrix
466      */
467     public static double[] adjugate(final double[] aSi, final int n)
468     {
469         if (n == 1)
470         {
471             return new double[] {1.0};
472         }
473         if (n == 2)
474         {
475             // adj([a b; c d]) = [ d -b; -c a ]
476             double a = aSi[0], b = aSi[1], c = aSi[2], d = aSi[3];
477             double[] adj = new double[] {d, -b, -c, a};
478             return adj;
479         }
480         if (n == 3)
481         {
482             double a = aSi[0], b = aSi[1], c = aSi[2];
483             double d = aSi[3], e = aSi[4], f = aSi[5];
484             double g = aSi[6], h = aSi[7], i = aSi[8];
485             // Cofactor matrix (not transposed yet)
486             double C00 = (e * i - f * h);
487             double C01 = -(d * i - f * g);
488             double C02 = (d * h - e * g);
489             double C10 = -(b * i - c * h);
490             double C11 = (a * i - c * g);
491             double C12 = -(a * h - b * g);
492             double C20 = (b * f - c * e);
493             double C21 = -(a * f - c * d);
494             double C22 = (a * e - b * d);
495             // Adjugate = Cofactor^T
496             double[] adj = new double[] {C00, C10, C20, C01, C11, C21, C02, C12, C22};
497             return adj;
498         }
499 
500         // General n: build cofactor matrix via minors, then transpose
501         int m = n - 1;
502         double[] cof = new double[n * n];
503         double[] minor = new double[m * m];
504 
505         for (int r = 0; r < n; r++)
506         {
507             for (int c = 0; c < n; c++)
508             {
509                 // Build minor excluding row r and col c
510                 int p = 0;
511                 for (int i = 0; i < n; i++)
512                 {
513                     if (i == r)
514                         continue;
515                     for (int j = 0; j < n; j++)
516                     {
517                         if (j == c)
518                             continue;
519                         minor[p++] = aSi[idx(n, i, j)];
520                     }
521                 }
522                 double detMinor;
523                 // note that m=1 and m=2 are not possible because they have been captured by n=1, n=2 and n=3
524                 if (m == 3)
525                 {
526                     double A = minor[0], B = minor[1], C = minor[2];
527                     double D = minor[3], E = minor[4], F = minor[5];
528                     double G = minor[6], H = minor[7], I = minor[8];
529                     detMinor = A * (E * I - F * H) - B * (D * I - F * G) + C * (D * H - E * G);
530                 }
531                 else
532                 {
533                     // Use LU for larger minors
534                     LU luMinor = luDecompose(minor, m);
535                     detMinor = detFromLU(luMinor, m);
536                 }
537                 double sign = ((r + c) & 1) == 0 ? 1.0 : -1.0;
538                 // Store cofactor (not yet transposed)
539                 cof[idx(n, r, c)] = sign * detMinor;
540             }
541         }
542         // Adjugate = cof^T
543         double[] adj = new double[n * n];
544         for (int r = 0; r < n; r++)
545         {
546             for (int c = 0; c < n; c++)
547             {
548                 adj[idx(n, r, c)] = cof[idx(n, c, r)];
549             }
550         }
551         return adj;
552     }
553 
554 }