hdu-6096

字典树,扫描线+线段树


题意

给n个原串,q个查询,每个查询给出一个前缀和后缀,问原串中有多少个串的前缀和后缀满足该查询

  • n,q<=100000
  • ∑Si+Pi≤500000,∑Wi≤500000

题解

  • 假设原始的字符串 数组为A,首先将A中的每个字符串都进行翻转,得到字符串数组B,然后,将A和B按字典序排序。
  • 对于一个查询来说有一个前缀p和后缀s, 所有包含前缀p的字符串在A中是连续的,可通过二分求出该区间 设为[Lp,Rp],同样,所有包含后缀s的字符串在B中也是连续的,设为[Ls,Rs]
  • 接下来只需求解 有多少个字符串前缀是在[Lp,Rp] 同时后缀在[Ls,Rs]。对于每个字符串,假设在A中是第x个,在B中是第y个 ,那么我们只需要判断有多少个字符串 Lp<=x<=Rp 同时 Ls<=y<=Rs
  • 该问题转化为,有一些点(每个字符串相当于一个点,x是按前缀排完序的位置,y是按后缀排序),现给定一些矩形(每个查询可转化为 Lp<=x<=Rp,Ls<=y<=Rs),问矩形中包含多少个点,该问题是经典的矩形覆盖问题,线段树+扫描线 即可求出。
  • 按上述方法求出后,会存在重叠的问题 。如有一个字符串 aaa 查询如果为 aa aa的话也会查到 aaa。 那么我们需要进行去重,可直接对查询的前缀或者后缀做一个遍历,枚举重叠的长度,然后再哈希判断是否存在这样的原始字符串即可。
  • 时间复杂度O(nlog(n)+∣S∣)
  • 避免hash可以离线暴力在字典树上建线段树,查询在字典树上找到后缀对应节点查找前缀区间和。空间O(∣S∣log(n))。时间复杂度O(nlog(n)+∣S∣)
  • 也可以直接hash离线做。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
#define maxnode 500010
#define sigma_size 26
#define maxn 100010
#define X 997
struct Point{
int x,y;
Point(int x=0,int y=0):x(x),y(y) {}
}po[maxn];
struct Trie {
int ch[maxnode][sigma_size];
int val[maxnode];
int id[maxnode];
int sz;
void init() {
sz=1;
memset(id,-1,sizeof(id));
memset(ch[0],0,sizeof(ch[0]));
val[0]=0;
}
int idx(char c) {
return c-'a';
}
void insert(char *s,int n,int o) {
int u=0;
for(int i=0; i<n; i++) {
int c=idx(s[i]);
if(!ch[u][c]) {
memset(ch[sz],0,sizeof(ch[sz]));
val[sz]=0;
ch[u][c]=sz++;
}
u=ch[u][c];
val[u]++;
}
id[u]=o;
}
void getPoint(int u,int r,bool ty){
if(~id[u]){
r++;
if(!ty) po[id[u]].x=r;
else po[id[u]].y=r;
}
for(int i=0; i<26; i++) {
if(ch[u][i]){
getPoint(ch[u][i],r,ty);
r+=val[ch[u][i]];
}
}
}
Point findSeg(char *s){
int n=strlen(s),u=0;
int r=0;
for(int i=0; i<n; i++) {
int c=idx(s[i]);
if(!ch[u][c]) return Point();
for(int j=0;j<c;j++)
if(ch[u][j]) r+=val[ch[u][j]];
if(~id[u]) r++;
u=ch[u][c];
}
return Point(r+1,r+val[u]);
}
}pref,suff;
char p[maxn],s[maxn];
int ans[maxn];
struct Line{
int x,y1,y2,ty,id;
Line() {}
Line(int x,int y1,int y2,int ty,int id):x(x),y1(y1),y2(y2),ty(ty),id(id) {}
bool operator <(const Line &rhs)const{
return x<rhs.x||x==rhs.x&&y2<rhs.y2;
}
}L[maxn*3];
struct SegTree{
int sumv[maxn*4];
void init(){
memset(sumv,0,sizeof(sumv));
}
void update(int o,int L,int R,int p){
if(L==R) sumv[o]++;
else{
int M=(L+R)/2;
if(p<=M) update(o*2,L,M,p);
else update(o*2+1,M+1,R,p);
sumv[o]++;
}
}
int query(int o,int L,int R,int qL,int qR){
if(qL<=L&&R<=qR) return sumv[o];
int M=(L+R)/2,ans=0;
if(qL<=M) ans+=query(o*2,L,M,qL,qR);
if(qR>M) ans+=query(o*2+1,M+1,R,qL,qR);
return ans;
}
}T;
set<ull> st;
ull xp[maxn];
int main() {
// freopen("1001.in","r",stdin);
// freopen("ans.txt","w",stdout);
int ca;
xp[0]=1;
for(int i=1;i<maxn;i++) xp[i]=xp[i-1]*X;
scanf("%d",&ca);
while(ca--) {
int n,q;
memset(ans,0,sizeof(ans));
st.clear();
scanf("%d%d",&n,&q);
pref.init();suff.init();
for(int i=0; i<n; i++) {
scanf("%s",s);
int len=strlen(s);
ull h=0;
for(int j=0;j<len;j++){
h=h*X+s[j]-'0';
}
st.insert(h);
pref.insert(s,len,i);
reverse(s,s+len);
suff.insert(s,len,i);
}
pref.getPoint(0,0,0);
suff.getPoint(0,0,1);
for(int i=0;i<n;i++)
L[i]=Line(po[i].x,po[i].y,0,0,0);
Point e;
int k=n,px,py,sx,sy;
for(int te=0;te<q;te++){
scanf("%s%s",p,s);
e=pref.findSeg(p);
if(e.x==0) continue;
px=e.x;py=e.y;
int l2=strlen(s);
reverse(s,s+l2);
e=suff.findSeg(s);
if(e.x==0) continue;
int l1=strlen(p);
ull h,h1=0,h2=0;
reverse(s,s+l2);
for(int j=0;j<l1;j++) h1=h1*X+p[j]-'0';
for(int j=0;j<l2;j++) h2=h2*X+s[j]-'0';
for(int i=0;i<min(l1,l2);i++){
if(p[l1-i-1]!=s[i]) break;
h1-=(s[i]-'0')*xp[i];
h=h1*xp[l2-i-1]+h2;
if(st.count(h)) ans[te]--;
}
sx=e.x;sy=e.y;
L[k++]=Line(px-1,sx,sy,-1,te);
L[k++]=Line(py,sx,sy,1,te);
}
sort(L,L+k);
T.init();
for(int i=0;i<k;i++){
if(!L[i].y2) T.update(1,1,n,L[i].y1);
else ans[L[i].id]+=L[i].ty*T.query(1,1,n,L[i].y1,L[i].y2);
}
for(int i=0;i<q;i++)
printf("%d\n",ans[i]);
}
return 0;
}