小美的仓库整理
发布于2021-04-16
上次编辑2021-04-16
问题描述
小美是美团仓库的管理员,她会根据单据的要求按顺序取出仓库中的货物,每取出一件货物后会把剩余货物重新堆放,使得自己方便查找。已知货物入库的时候是按顺序堆放在一起的。如果小美取出其中一件货物,则会把货物所在的一堆物品以取出的货物为界分成两堆,这样可以保证货物局部的顺序不变。
已知货物最初是按 1~n 的顺序堆放的,每件货物的重量为 w[i] ,小美会根据单据依次不放回的取出货物。请问根据上述操作,小美每取出一件货物之后,重量和最大的一堆货物重量是多少?
格式:
输入:
- 输入第一行包含一个正整数 n ,表示货物的数量。
- 输入第二行包含 n 个正整数,表示 1~n 号货物的重量 w[i] 。
- 输入第三行有 n 个数,表示小美按顺序取出的货物的编号,也就是一个 1~n 的全排列。
输出:
- 输出包含 n 行,每行一个整数,表示每取出一件货物以后,对于重量和最大的一堆货物,其重量和为多少。
示例:
输入:
5
3 2 4 4 5
4 3 5 2 1
输出:
9
5
5
3
0
解释:
原本的状态是 {{3,2,4,4,5}} ,取出 4 号货物后,得到 {{3,2,4},{5}} ,第一堆货物的和是 9 ,然后取出 3 号货物得到 {{3,2}{5}} ,此时第一堆和第二堆的和都是 5 ,以此类推。
提示:
1 <= n,m <= 50000
1 <= w[i] <= 100
解题思路
方法一:线段树
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50 | import sys
from itertools import accumulate
from typing import List
readline = sys.stdin.readline
def readint():
return int(readline().strip())
def readstr():
return readline().strip()
def readints():
return list(map(int, readline().strip().split()))
class Node:
def __init__(self, start: int, end: int, max: int):
self.start, self.end, self.max = start, end, max
self.left, self.right = None, None
def update(root: Node, k: int, presum: List[int]) -> None:
if root.start < k < root.end and not root.left and not root.right:
# left: [root->start+1,k-1], right: [k+1,root->end-1]
leftsum = (k - 1 >= 0 and presum[k - 1]) - (
root.start >= 0 and presum[root.start]
)
rightsum = (root.end >= 0 and presum[root.end - 1]) - (
k >= 0 and presum[k]
)
root.left = Node(root.start, k, leftsum)
root.right = Node(k, root.end, rightsum)
elif root.left and k < root.left.end:
update(root.left, k, presum)
else:
update(root.right, k, presum)
root.max = max(root.left.max, root.right.max)
n = readint()
presum = list(accumulate(readints()))
s = readints()
root = Node(-1, n, presum[n - 1])
for k in s:
update(root, k - 1, presum)
print(root.max)
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 | #include <bits/stdc++.h>
using namespace std;
struct Node {
int start, end, max;
Node *left, *right;
Node(int start, int end, int max)
: start(start), end(end), max(max), left(nullptr), right(nullptr){};
};
void update(Node *root, int k, vector<int> &presum) {
if (root->start < k && k < root->end && !root->left && !root->right) {
// left: [root->start+1,k-1], right: [k+1,root->end-1]
int leftsum = (k - 1 >= 0 ? presum[k - 1] : 0) -
(root->start >= 0 ? presum[root->start] : 0);
int rightsum = (root->end - 1 >= 0 ? presum[root->end - 1] : 0) -
(k >= 0 ? presum[k] : 0);
root->left = new Node(root->start, k, leftsum);
root->right = new Node(k, root->end, rightsum);
} else if (root->left && k < root->left->end) {
update(root->left, k, presum);
} else {
update(root->right, k, presum);
}
root->max = max(root->left->max, root->right->max);
}
int main() {
ios_base::sync_with_stdio(false), cin.tie(nullptr);
int n, w, k, sum = 0;
cin >> n;
vector<int> presum(n);
for (int i = 0; i < n; ++i) {
cin >> w;
sum += w;
presum[i] = sum;
}
Node *root = new Node(-1, n, presum[n - 1]);
for (int i = 0; i < n; ++i) {
cin >> k;
update(root, k - 1, presum);
cout << root->max << '\n';
}
return 0;
}
|
但是最差情况下的复杂度是 \(\mathcal{O}(n^)\),如
| 50000
1 2 ... 50000
1 2 ... 50000
|
这组数据就会超时,只是题目数据比较弱所以能够通过。
方法二:反向添加+并查集
按照与取货相反的顺序,不断添加货物。
然后使用并查集,将添加的第 \(x\) 个货物与其左 \(x-1\) 右 \(x+1\) 两边的货物堆进行合并,合并的前提是:
- \(x-1\) 和 \(x+1\) 的下标合法
- \(x-1\) 和 \(x+1\) 已经存在(即已经添加过)
然后在添加和合并的过程中维护最大值即可。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56 | import sys
readline = sys.stdin.readline
def readint():
return int(readline().strip())
def readstr():
return readline().strip()
def readints():
return list(map(int, readline().strip().split()))
n = readint()
w = readints()
o = readints()
mx = 0
fa = list(range(n))
size = [1] * n
sum = [0] * n
ans = [0] * n
def find_set(x: int) -> int:
if x != fa[x]:
fa[x] = find_set(fa[x])
return fa[x]
def union_sets(x: int, y: int) -> None:
global mx
if 0 <= y < n:
if sum[y]:
x, y = find_set(x), find_set(y)
if size[x] > size[y]:
x, y = y, x
fa[x] = y
size[y] += size[x]
sum[y] += sum[x]
mx = max(mx, sum[y])
for i in range(n - 1, -1, -1):
ans[i] = mx
x = o[i] - 1
sum[x] = w[x]
mx = max(mx, sum[x])
union_sets(x, x - 1)
union_sets(x, x + 1)
print("\n".join(map(str, ans)))
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46 | #include <bits/stdc++.h>
using namespace std;
int main() {
ios_base::sync_with_stdio(false), cin.tie(nullptr);
int n, mx = 0;
cin >> n;
vector<int> w(n), o(n), ans(n);
vector<int> fa(n), size(n), sum(n);
for (int i = 0; i < n; ++i) {
cin >> w[i];
fa[i] = i, size[i] = 1, sum[i] = 0;
}
function<int(int)> find_set = [&](int x) {
return x == fa[x] ? x : (fa[x] = find_set(fa[x]));
};
function<void(int, int)> union_sets = [&](int x, int y) {
if (0 <= y && y < n) {
x = find_set(x), y = find_set(y);
if (sum[y]) {
if (size[x] > size[y]) {
swap(x, y);
}
fa[x] = y;
size[y] += size[x];
sum[y] += sum[x];
mx = max(mx, sum[y]);
}
}
};
for (int i = 0; i < n; ++i) {
cin >> o[i];
o[i] -= 1;
}
for (int i = n - 1; i >= 0; --i) {
ans[i] = mx;
int x = o[i];
sum[x] = w[x];
mx = max(mx, sum[x]);
union_sets(x, x - 1);
union_sets(x, x + 1);
}
for (int i = 0; i < n; ++i) {
cout << ans[i] << '\n';
}
return 0;
}
|