两个递增排序的整数序列 A, B,长度同为N,求前K个最小的 a[i] + b[j]


递增排序的整数序列 A={a(i)} B = {b(j)} 长度同为 N,两个数组相加得到 N 2 个数,再对这些数进行排序,算法时间复杂度很高啊。有什么更好的办法吗?

排序 算法

鹿过_FIGO 9 years, 7 months ago

A B -> sort asc


 i = 0, j = 0, k = 1;
while(k < K)
{
    if (A[i + 1] + B[j] < A[i] + B[j + 1]) then i++; else j++;
    k++;
}

知道才有鬼 answered 9 years, 7 months ago

用小顶堆 应该可以减少一些时间复杂度. 不过我们可以试试用搜索的办法.

建立 大小为min(N, K)的一维数组A. 对应到二维来说, A[i] 表示 对第i行 的最右的被选的点. 如此可以把 解空间 分为(已选, 未选) 两部分. 这样判定 某点(x, y)是否被选, y <= A[i] 即被选中.
: 这里可以用 N x N的 状态矩阵 来理解, 标识 "两个数组相加得到 N2 个数" 的状态, x.y表示a[x]+b[y] 的状态. 初始化 M 为: M0.0 为 已选出, M0.1 和 M1.0 为 "待选", 其他为 "未选"


 结果集R, 初始化为().

待选集S, 初始化为 (a0b0).

步骤:
1. 从 待选集S 中去掉最小的s, 把s加入 结果集R, 假设其为 a[i]+b[j];
2. 如果结果集 元素数量 = K, 退出;
2. 在M中找 Mi+1.j 和 Mi.j+1 的状态. 对某个Mx.y来说, 仅当Mx-1.y 和Mx.y-1 都为 "已选出", 才把axby加入待选集S; 使用状态数组 来确定 某个点的 状态.
3. 重复1.

时间复杂度: 待选集S 最大 为min(N, K+1). 选最小数可以用 小顶堆. 所以复杂度为 K*lg (min(N, K+1))

说明:


 仅当Mx-1.y 和Mx.y-1 都为 "已选出", 才把axby加入待选集S -- 如果Mx-1.y 和Mx.y-1有一个为"待选"
或"未选", 则待选集S中必有元素 比axby 要小.


写了java的实现, 这里 待选集 用了java内置的 PriorityQueue, 其基本操作都是logN的.


 import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

import org.junit.Test;

public class TopKSumOfTwoSortArrayFinderV2 {

    private PriorityQueue<Node> interSet;
    private List<Node> resultList;
    private int statusArray[];

    private static class Node implements Comparable<Node>{
        int indexX;
        int indexY;
        int value;

        public Node(int indexX, int indexY, int value) {
            this.indexX = indexX;
            this.indexY = indexY;
            this.value = value;
        }

        @Override
        public int compareTo(Node n) {
            return value - n.value;
        }

        @Override
        public String toString(){
            return "x:" + indexX + ", y:" + indexY + ", value:" + value;
        }
    }

    private void getTopKSum(int[] x, int[] y, int size, int K) {
        init(x, y, size, K);

        while(true){
            System.out.println("heap size: " + interSet.size());
            Node currentNode = interSet.poll();
            resultList.add(currentNode);
            int indexX = currentNode.indexX;
            int indexY = currentNode.indexY;
            select(indexX, indexY);
            printSelected(currentNode);
            if(resultList.size() >= K) break;

            if(indexX < size - 1){ // then currentNode has right node
                Node rightNode = new Node(indexX + 1, indexY, 
                        x[indexX + 1] + y[indexY]); 
                if( rightNode.indexY == 0 || ifSelected(indexX + 1, indexY - 1)){
                    // right node has not upper node or upper node is selected 
                    interSet.add(rightNode);
                }
            }
            if(indexY < size - 1){ // then currentNode has lower node
                Node lowerNode = new Node(indexX, indexY + 1, x[indexX] 
                        + y[indexY + 1]); 
                if( lowerNode.indexX == 0 || ifSelected(indexX - 1, indexY + 1)){
                    // lower node has not left node or left node is selected 
                    interSet.add(lowerNode);
                }
            }
        }
    }

    private void printSelected(Node n) {
        System.out.println("selected: " + n);
    }

    private void select(int x, int y){
        if(y > statusArray[x])
            statusArray[x] = y;
    }

    private boolean ifSelected(int x, int y){
        return y <= statusArray[x];
    }

    private void init(int[] x, int[] y, int size, int k) {
        statusArray = new int[size > k ? size : k];
        Arrays.fill(statusArray, -1);

        interSet = new PriorityQueue<>(size);
        interSet.add(new Node(0, 0, x[0] + y[0]));

        resultList = new ArrayList<>();
    }

    @Test
    public void test(){
        int[] a={1,2,3,4,5,6,7,8};
        int[] b={100,200,300,400,500,600,700,800};
        getTopKSum(a, b, 8, 40);
    }
}

小A真是欠草 answered 9 years, 7 months ago

数据集定义

谢谢 @brayden 提供的思路。

将M(n*n)矩阵分为三个区域:

  • 已经遍历 && 已经选择(结果集R)
  • 已经遍历 && 未选择(待选集S, 使用最小堆结构)
  • 未遍历(待遍历集U)

操作

  1. 初始化,将(a0b0)放到S集(最小堆)
  2. 从S 集删除最小元素(即堆顶),最小元素放到R 结果集。(如果R足够K, 就结束)
  3. 从待遍历集U中, 选出可能是下一个或者两个 能进入R的 小元素, 放到S集中
  4. 回到步骤2 继续

图片描述

复杂度

S集最多为min(K, n). 所以时间复杂度为 O(K*log(min(K, n)))

代码示例


 #include <stdio.h>
#include <stdlib.h>
#include <math.h>

struct heap{
    int ai,bi;
    int v;
};
int S_n= 0;

void print_heap(struct heap *S){
    int i=0,j=0;
    int pad = (int)log2((float)S_n);
    int k = 1;
    printf("\n---- print S start --------\n");
    while(i<S_n){
        for(j=0; j<pad;j++){
            printf("%8c", ' ');
        }
        for(j = i+k; i<j && i<S_n; i++){
            printf("%-2d(%d,%d)\t\t", S[i].v, S[i].ai, S[i].bi);
        }
        printf("\n");
        k *=2;
        pad--;
    }
    printf("\n---- print S end --------\n");
}
void heap_swap(struct heap *a, struct heap *b){
    struct heap t;
    t=*a;
    *a=*b;
    *b=t;
}

/**
 * 向下调整堆: O(logN)
 */
void min_heap_down(struct heap *S, int i){
    int left,right;
    while(2*i+1<S_n){
        int mini;
        left = 2*i+1;
        right = left+1;
        mini = left;
        if(right < S_n && S[right].v < S[left].v){
            mini = right;
        }
        if(S[mini].v < S[i].v){
            heap_swap(S+mini, S+i);
            i = mini;
        }else{
            break;
        }
    }
}
/**
 * min_heap_insert: O(logN)
 */
void min_heap_insert(struct heap *S, int ai, int bi, int v){
    int i,j,temp;
    struct heap node = {ai, bi, v};
    S[S_n] = node;
    i = S_n;
    while(i > 0){
        j = (i-1)/2;
        if(S[j].v > S[i].v){
            heap_swap(S+i, S+j);
        }else{
            break;
        }
        i = j;
    }   
    S_n++;
}

/**
 * min_heap_del: O(logN)
 */
void min_heap_del(struct heap *S, int i){
    heap_swap(&S[i], &S[--S_n]);
    min_heap_down(S, i);
}

/**
 * topK of [a + b]
 */
int topK(int *a, int *b, int n, int K, int *R){
    int k=0;
    struct heap *S;
    S = malloc( sizeof(struct heap) * (K<n ? K :n));

    min_heap_insert(S, 0, 0, a[0]+b[0]);

    int i,j;
    while(S_n >0){
        struct heap node = S[0];
        min_heap_del(S, 0);
        R[k] = node.v;
        printf("Result: %d k=%d\n", node.v, k);

        if(++k < K){
            if(node.bi + 1 < n){
                min_heap_insert(S, node.ai, node.bi+1, a[node.ai] + b[node.bi+1]);
            }
            if(node.bi == 0 && node.ai+1 < n){
                min_heap_insert(S, node.ai+1, 0, a[node.ai+1] + b[0]);
            }
        }else{
            break;
        }
        //print_heap(S);
    };
    free(S);
    return k;
}

#define N 3
/**
 * 时间复杂度为: O(K * log(min(K,n)))
 */
int main(void) {
    int a[N] = {2,4,5};
    int b[N] = {1,2,6};
    int n=N;
    int K = 6;
    int *R = malloc( sizeof(int) * K);
    int k = topK(a, b, n, K, R);
    return 0;
}

Leopard answered 9 years, 7 months ago

Your Answer