矩阵连乘——动态规划和贪心

​ 男朋友决定每天写进blog一个算法课上的示例,我不想落后,决定一起开始这个工程啦(#^.^#)。第一个当然是我比较感兴趣的矩阵连乘求最优的算法了。

1
2
3
输入:<A1,A2, ..., An>, Ai是(pi -1) * pi矩阵 

输出:计算A1 * A2 *... * An的最小代价方法
动态规划

一般给出的是动态规划算法,使用动态规划,是因为矩阵连乘可以分解为若干个具有重叠性的子问题:

​ 我们可以递归的求解每一种方式的代价,进而通过比较得到乘法次数最少的方案即最优解:
$$
A_i\times A_{i+1}\times…\times A_j = \begin{cases}(A_i) & \times(A_{i+1}\times…\times A_j)\(A_i\times A_{i+1})&\times(A_{i+2}\times…\times A_j)\…\(A_i\times…\times A_k)\times(A_{k+1}\times…\times A_j)\…\(A_i\times…\times A_{j-1})\times(A_j)\end{cases}
$$
​ 考虑到所有的k,优化解的代价方程为:

​ $m[i, j]= 0 , if: i=j $

​ $m[i, j]= min_{i<k<j}{ m[i, k]+m[k+1, j]+p_{i-1}p_kp_j} , if : i<j$

​ 由此,我们可以自底向上的递归求解该问题,第一步:求m[1,1],接着可求解m[2,2]、m[1,2],然后求m[3,3]、m[2,3]、m[1,3]、、、求解过程中,动态的求解出了每一步的最优解,最终的结果m[1,n]即是我们要找的最优解:

算法伪代码
1
2
3
4
5
6
7
8
9
FOR  i=1 TO  n DO
m[i, i]=0;
FOR l=2 TO n /* 计算 l 对角线 */
FOR i=1 TO n-l+1
j=i+l-1;
m[i, j]=∞;
FOR k = i To j-1 /* 计算m[i,j] */
q=m[i, k]+m[k+1, j]+pi-1pkpj
IF q<m[i, j] THEN m[i,j]=q;

时间复杂度:O($n^3$),空间复杂度O($n^2$)

代码示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/**
* 分治算法求最优解
*
* @return 矩阵连乘的最优解所需乘法次数
*/
static void devideConquer() {
sum = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
path[i][j] = 0;
result[i][j] = infty;
if (i == j) {
result[i][j] = 0;
}
}
}
for (int l = 1; l < n; l++ ){
for (int i = 0; i <= n - l - 1; i++) {
int j = i + l;
result[i][j] = infty;
for (int k = i; k < j; k++) {
int temp = (matrixs.get(i).getX()) * (matrixs.get(k).getY()) * (matrixs.get(j).getY());
int num = result[i][k] + result[k + 1][j] + temp;
if (num < result[i][j]) {
result[i][j] = num;
path[i][j] = k;//用于记录运算的先后顺序
}
}
}
}
}
贪心算法

​ 贪心算法循环从连乘的矩阵中选出两个相邻矩阵乘法次数最小的优先进行乘法运算,并将运算结果更新到连乘的矩阵中,直到最后只有一个矩阵为止。

​ 需要注意的是,贪心算法每次找到局部最优解,但结果不一定是全局最优解。

代码示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/**
* 贪心算法
*/
static void greedy() {
sum = 0;
ArrayList<Pair> tempPairs = matrixs;
int nn = tempPairs.size();
int min = infty;
int tempNum = 0;
int index = 0;//记录每次循环中需要先计算的矩阵下标

while(nn > 1) {
min = infty;
index = 0;
for(int i = 0; i< nn-1; i++) {
tempNum = tempPairs.get(i).getX()*tempPairs.get(i).getY()*tempPairs.get(i+1).getY();
if(min > tempNum) {
min = tempNum;
index = i;
}
}
Pair pair = new Pair(tempPairs.get(index).getX(), tempPairs.get(index+1).getY());
tempPairs.remove(index+1);
tempPairs.remove(index);
tempPairs.add(index, pair);
nn = tempPairs.size();
sum += min;
}
System.out.println(sum);
}