【题意】给出一棵树,每个结点有个权值a[i],m个询问,每个询问(x,y,z),输出x和y路径上a[i]^z的最大值。a[i],z < 2^16
【分析】首先如果区间固定,那么可以直接把区间内的所有点插入trie(从高位插入,每次插入16位);然后直接在trie中询问即可。但是这里区间会改变,其实可以像可持久化线段树一样的方法做:每个结点建一棵trie依赖父亲,然后求出lca之后便在查找的时候套用lca的距离公式 dis = dis[x]+dis[y]-2*dis[lca](这里统计的是点而不是边,所以代码里实际用的是 sum[x]+sum[y]-sum[lca]-sum[fa[lca]])。吐槽下自己弱:找了半天bug竟然是lca线段树写挫了(本以为线段树很自信了,唉)。
【AC CODE】1809ms
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <cstdio>
#include <cstring>
#include <cctype>
#include <cmath>
#include <set>
#include <bitset>
//#include <unordered_set>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
// 64-bit integer alias.
typedef long long LL;
// rep: i runs upward over the half-open range [a, n).
#define rep(i,a,n) for(int i = a; i < n; i++)
// repe: i runs upward over the closed range [a, n].
#define repe(i,a,n) for(int i = a; i <= n; i++)
// per: i runs downward from n to a, both inclusive.
#define per(i,n,a) for(int i = n; i >= a; i--)
// clc: memset the whole array `a` with byte value `b` (needs a real array, not a pointer, for sizeof).
#define clc(a,b) memset(a,b,sizeof(a))
// INF  : sentinel depth stored in d[0] by main.
// MAXN : capacity of the per-node arrays; MAXM : capacity of the edge arrays
//        (add_edge is called twice per input edge).
// BIT  : index of the highest bit visited by insert / trie_query.
const int INF = 0x3f3f3f3f, MAXN = 200000+10, MAXM = MAXN<<1, BIT = 17;
// Adjacency list: head[u] is the first edge slot of u (-1 when empty), nxt[e] the next
// slot in u's list, to[e] the endpoint of slot e, tol the number of slots in use.
int head[MAXN],tol,nxt[MAXM],to[MAXM];
// a[u]  : value read for node u.
// ch    : children of each trie node; sum : counter kept per trie node (incremented by insert);
// all   : index of the last trie node allocated; rt[u] : trie root recorded for node u.
// Trie node 0 is never allocated, so it acts as the empty trie.
int a[MAXN], ch[MAXN*50][2],sum[MAXN*50],all,rt[MAXN];
// cnt   : length of the tour written by dfs; ft[u] : position of u's first visit in the tour;
// d[k]  : depth recorded at tour position k; num[k] : node recorded at tour position k;
// fa[u] : parent of u as assigned by dfs.
int cnt, ft[MAXN],d[MAXN<<1],num[MAXN<<1],fa[MAXN];
// Persistent-trie insertion of value `v`.
// Starting from the node index stored in `u`, every node on the path chosen by
// the bits of `v` (bit `p` first, bit 0 last) is copied into a freshly allocated
// node whose counter is one larger; one more node is created below bit 0.
// `u` is overwritten with the index of the new top node. Only new nodes are
// written, so previously recorded node indices keep their old contents.
void insert(int &u, int v, int p)
{
    int *slot = &u;
    for(int bit = p; ; --bit)
    {
        const int old = *slot;
        const int fresh = ++all;
        ch[fresh][0] = ch[old][0];
        ch[fresh][1] = ch[old][1];
        sum[fresh] = sum[old]+1;
        *slot = fresh;
        if(bit < 0) break;
        // Continue in the child selected by the current bit of v.
        slot = &ch[fresh][(v&(1<<bit)) != 0];
    }
}
int trie_query(int x, int y, int lca, int val)
{
int fac = rt[fa[lca]];
x = rt[x], y = rt[y], lca = rt[lca];
int ans = 0;
per(i,BIT,0)
{
bool v = val&(1<<i);
if(sum[ch[y][v^1]]+sum[ch[x][v^1]]-sum[ch[lca][v^1]]-sum[ch[fac][v^1]])
{
ans |= 1<<i;
x = ch[x][v^1],y = ch[y][v^1],lca = ch[lca][v^1],fac = ch[fac][v^1];
}
else if(sum[ch[y][v]]+sum[ch[x][v]]-sum[ch[lca][v]]-sum[ch[fac][v]])
x = ch[x][v],y = ch[y][v],lca = ch[lca][v],fac = ch[fac][v];
}
return ans;
}
// Pushes a directed edge u -> v onto the front of u's adjacency list,
// using the next free edge slot.
inline void add_edge(int u, int v)
{
    const int slot = tol++;
    to[slot] = v;
    nxt[slot] = head[u];
    head[u] = slot;
}
// Depth-first walk from `u` at depth `dep`.
// Records the tour (d / num / ft), derives u's trie version from its parent's
// version plus a[u], and sets fa[] for each child before descending into it.
// After returning from each child, u is appended to the tour again.
void dfs(int u, int dep)
{
    ++cnt;
    ft[u] = cnt;
    num[cnt] = u;
    d[cnt] = dep;

    // Start from the parent's version, then add this node's value on top.
    rt[u] = rt[fa[u]];
    insert(rt[u],a[u],BIT);

    for(int e = head[u]; e != -1; e = nxt[e])
    {
        const int child = to[e];
        if(child != fa[u])
        {
            fa[child] = u;
            dfs(child,dep+1);
            ++cnt;
            d[cnt] = dep;
            num[cnt] = u;
        }
    }
}
// Segment tree over tour positions, indexed by id(l, r):
// mi holds the minimum d[] value in the range, minum the tour position where it occurs.
int mi[MAXN<<2],minum[MAXN<<2];
// Array slot for the segment-tree node covering [x, y]:
// x + y for a single-element range, (x + y) with the lowest bit set otherwise.
inline int id(int x, int y)
{
    const int total = x+y;
    return x == y ? total : (total|1);
}
void bulid(int x, int y)
{
if(x == y)
{
mi[id(x,y)] = d[x];
minum[id(x,y)] = x;
return;
}
int m = (x+y)>>1;
bulid(x,m);
bulid(m+1,y);
int u = id(x,y), l = id(x,m),r = id(m+1,y);
if(mi[l] < mi[r]) mi[u] = mi[l],minum[u] = minum[l];
else mi[u] = mi[r],minum[u] = minum[r];
}
// Returns the tour position with the smallest d[] value inside [ql, qr],
// searching the segment-tree node that covers [x, y].
// Position 0 is the "nothing found yet" value; main stores INF in d[0] so any
// real position compares smaller.
int query(int x, int y, int ql, int qr)
{
    // Node fully inside the requested range: its stored answer is final.
    if(x >= ql && qr >= y) return minum[id(x,y)];

    const int mid = (x+y)>>1;
    int best = 0;
    if(mid >= ql)
    {
        const int cand = query(x,mid,ql,qr);
        if(d[cand] < d[best]) best = cand;
    }
    if(mid < qr)
    {
        const int cand = query(mid+1,y,ql,qr);
        if(d[cand] < d[best]) best = cand;
    }
    return best;
}
// Lowest common ancestor of x and y: the node recorded at the minimum-depth
// tour position between the first visits of the two nodes.
int lca_query(int x, int y)
{
    const int px = ft[x];
    const int py = ft[y];
    const int lo = px > py ? py : px;
    const int hi = px > py ? px : py;
    return num[query(1,cnt,lo,hi)];
}
int main()
{
#ifdef SHY
freopen("d:\\1.txt", "r", stdin);
#endif
int n,m;
d[0] = INF;
while(~scanf("%d %d", &n, &m))
{
repe(i,1,n) scanf("%d", &a[i]);
tol = 0;
clc(head,-1);
rep(i,1,n)
{
int u,v;
scanf("%d %d", &u, &v);
add_edge(u,v);
add_edge(v,u);
}
cnt = all = 0;
dfs(1,0);
bulid(1,cnt);
while(m--)
{
int x,y,v;
scanf("%d %d %d", &x, &y, &v);
int lca = lca_query(x,y);
int ans = trie_query(x,y,lca,v);
printf("%d\n", ans);
}
}
return 0;
}