跳转至

小美的仓库整理#
发布于2021-04-16
上次编辑2021-04-16

问题描述#

小美是美团仓库的管理员,她会根据单据的要求按顺序取出仓库中的货物,每取出一件货物后会把剩余货物重新堆放,使得自己方便查找。已知货物入库的时候是按顺序堆放在一起的。如果小美取出其中一件货物,则会把货物所在的一堆物品以取出的货物为界分成两堆,这样可以保证货物局部的顺序不变。
已知货物最初是按 1~n 的顺序堆放的,每件货物的重量为 w[i] ,小美会根据单据依次不放回的取出货物。请问根据上述操作,小美每取出一件货物之后,重量和最大的一堆货物重量是多少?

格式:

输入:
- 输入第一行包含一个正整数 n ,表示货物的数量。
- 输入第二行包含 n 个正整数,表示 1~n 号货物的重量 w[i] 。
- 输入第三行有 n 个数,表示小美按顺序取出的货物的编号,也就是一个 1~n 的全排列。
输出:
- 输出包含 n 行,每行一个整数,表示每取出一件货物以后,对于重量和最大的一堆货物,其重量和为多少。

示例:

输入:
     5
     3 2 4 4 5 
     4 3 5 2 1
输出:
     9
     5
     5
     3
     0
解释:
原本的状态是 {{3,2,4,4,5}} ,取出 4 号货物后,得到 {{3,2,4},{5}} ,第一堆货物的和是 9 ,然后取出 3 号货物得到 {{3,2}{5}} ,此时第一堆和第二堆的和都是 5 ,以此类推。

提示:

  • 1 <= n,m <= 50000
  • 1 <= w[i] <= 100

解题思路#

方法一:线段树#

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
from itertools import accumulate
from typing import List

readline = sys.stdin.readline


def readint():
    return int(readline().strip())


def readstr():
    return readline().strip()


def readints():
    return list(map(int, readline().strip().split()))


class Node:
    def __init__(self, start: int, end: int, max: int):
        self.start, self.end, self.max = start, end, max
        self.left, self.right = None, None


def update(root: Node, k: int, presum: List[int]) -> None:
    if root.start < k < root.end and not root.left and not root.right:
        # left: [root->start+1,k-1], right: [k+1,root->end-1]
        leftsum = (k - 1 >= 0 and presum[k - 1]) - (
            root.start >= 0 and presum[root.start]
        )
        rightsum = (root.end >= 0 and presum[root.end - 1]) - (
            k >= 0 and presum[k]
        )
        root.left = Node(root.start, k, leftsum)
        root.right = Node(k, root.end, rightsum)
    elif root.left and k < root.left.end:
        update(root.left, k, presum)
    else:
        update(root.right, k, presum)
    root.max = max(root.left.max, root.right.max)


n = readint()
presum = list(accumulate(readints()))
s = readints()
root = Node(-1, n, presum[n - 1])
for k in s:
    update(root, k - 1, presum)
    print(root.max)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;
struct Node {
    int start, end, max;
    Node *left, *right;
    Node(int start, int end, int max)
        : start(start), end(end), max(max), left(nullptr), right(nullptr){};
};
void update(Node *root, int k, vector<int> &presum) {
    if (root->start < k && k < root->end && !root->left && !root->right) {
        // left: [root->start+1,k-1], right: [k+1,root->end-1]
        int leftsum = (k - 1 >= 0 ? presum[k - 1] : 0) -
                      (root->start >= 0 ? presum[root->start] : 0);
        int rightsum = (root->end - 1 >= 0 ? presum[root->end - 1] : 0) -
                       (k >= 0 ? presum[k] : 0);
        root->left = new Node(root->start, k, leftsum);
        root->right = new Node(k, root->end, rightsum);
    } else if (root->left && k < root->left->end) {
        update(root->left, k, presum);
    } else {
        update(root->right, k, presum);
    }
    root->max = max(root->left->max, root->right->max);
}
int main() {
    ios_base::sync_with_stdio(false), cin.tie(nullptr);
    int n, w, k, sum = 0;
    cin >> n;
    vector<int> presum(n);
    for (int i = 0; i < n; ++i) {
        cin >> w;
        sum += w;
        presum[i] = sum;
    }
    Node *root = new Node(-1, n, presum[n - 1]);
    for (int i = 0; i < n; ++i) {
        cin >> k;
        update(root, k - 1, presum);
        cout << root->max << '\n';
    }
    return 0;
}

但是最差情况下的复杂度是 \(\mathcal{O}(n^)\),如

1
2
3
50000
1 2 ... 50000
1 2 ... 50000

这组数据就会超时,只是题目数据比较弱所以能够通过。

方法二:反向添加+并查集#

按照与取货相反的顺序,不断添加货物。

然后使用并查集,将添加的第 \(x\) 个货物与其左 \(x-1\)\(x+1\) 两边的货物堆进行合并,合并的前提是:

  1. \(x-1\)\(x+1\) 的下标合法
  2. \(x-1\)\(x+1\) 已经存在(即已经添加过)

然后在添加和合并的过程中维护最大值即可。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import sys

readline = sys.stdin.readline


def readint():
    return int(readline().strip())


def readstr():
    return readline().strip()


def readints():
    return list(map(int, readline().strip().split()))


n = readint()
w = readints()
o = readints()

mx = 0
fa = list(range(n))
size = [1] * n
sum = [0] * n
ans = [0] * n


def find_set(x: int) -> int:
    if x != fa[x]:
        fa[x] = find_set(fa[x])
    return fa[x]


def union_sets(x: int, y: int) -> None:
    global mx
    if 0 <= y < n:
        if sum[y]:
            x, y = find_set(x), find_set(y)
            if size[x] > size[y]:
                x, y = y, x
            fa[x] = y
            size[y] += size[x]
            sum[y] += sum[x]
            mx = max(mx, sum[y])


for i in range(n - 1, -1, -1):
    ans[i] = mx
    x = o[i] - 1
    sum[x] = w[x]
    mx = max(mx, sum[x])
    union_sets(x, x - 1)
    union_sets(x, x + 1)

print("\n".join(map(str, ans)))
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <bits/stdc++.h>
using namespace std;
int main() {
    ios_base::sync_with_stdio(false), cin.tie(nullptr);
    int n, mx = 0;
    cin >> n;
    vector<int> w(n), o(n), ans(n);
    vector<int> fa(n), size(n), sum(n);
    for (int i = 0; i < n; ++i) {
        cin >> w[i];
        fa[i] = i, size[i] = 1, sum[i] = 0;
    }
    function<int(int)> find_set = [&](int x) {
        return x == fa[x] ? x : (fa[x] = find_set(fa[x]));
    };
    function<void(int, int)> union_sets = [&](int x, int y) {
        if (0 <= y && y < n) {
            x = find_set(x), y = find_set(y);
            if (sum[y]) {
                if (size[x] > size[y]) {
                    swap(x, y);
                }
                fa[x] = y;
                size[y] += size[x];
                sum[y] += sum[x];
                mx = max(mx, sum[y]);
            }
        }
    };
    for (int i = 0; i < n; ++i) {
        cin >> o[i];
        o[i] -= 1;
    }
    for (int i = n - 1; i >= 0; --i) {
        ans[i] = mx;
        int x = o[i];
        sum[x] = w[x];
        mx = max(mx, sum[x]);
        union_sets(x, x - 1);
        union_sets(x, x + 1);
    }
    for (int i = 0; i < n; ++i) {
        cout << ans[i] << '\n';
    }
    return 0;
}
返回顶部

在手机上阅读