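解：首先有个常用套路——把每条边的权值定为 $[\text{两端点颜色不同}]$（颜色不同为 1，相同为 0），这样一条路径上的颜色段数就等于该路径的边权和加 1。边权用树链剖分 + 线段树直接维护（线段树每个区间记录左端颜色、右端颜色以及区间内相邻异色对的个数），支持链上染色修改。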
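每次询问对给出的 $K$ 个点建虚树，虚树上每条边的权值用树剖查询原树对应路径上的边权和。设 $dep_u$ 为虚树上 $u$ 到虚树根的边权和，则点 $i$ 的答案为 $\sum_j \left(\mathrm{dist}(i,j)+1\right) = K\cdot dep_i + \sum_j dep_j - 2\sum_j dep_{\mathrm{lca}(i,j)} + K$。其中 $\sum_j dep_{\mathrm{lca}(i,j)}$ 用开店的方法求：每个点向根做链加、再对 $i$ 到根做链查；在虚树上这等价于两遍树形 DP（先求每棵子树内的询问点数，再自顶向下累加“子树询问点数 × 边权”）。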
/// Solution outline: heavy-light decomposition over the original tree keeps the
/// per-edge weight [colours of the two endpoints differ] under path recolouring;
/// each query builds a virtual tree on its K vertices (edge weight = weighted
/// length of the corresponding original path) and answers with two DP passes.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>

#define forson(x, i) for(int i = e[x]; i; i = edge[i].nex)

typedef long long LL;
const int N = 100010;

/// Adjacency-list edge: `nex` = next edge of the same tail (0 ends the list),
/// `v` = head vertex, `len` = weight (only filled in for virtual-tree edges).
struct Edge {
    int nex, v;
    LL len;
} edge[N << 1], EDGE[N];
int tp, TP;  // edge counters: original tree / current virtual tree

// Original tree and its heavy-light decomposition.
int e[N], top[N], fa[N], son[N], siz[N], d[N], pos[N], id[N], num, val[N], n, imp2[N];
// Segment tree over HLD order: colour at the left / right end of each interval,
// number of adjacent differently-coloured positions inside it, and the pending
// "assign colour" tag (-1 = no pending tag).
int sum[N << 2], lc[N << 2], rc[N << 2], tag[N << 2];
// Virtual-tree state. vis[] / use[] are compared against Time so per-query
// state is reset lazily instead of clearing whole arrays every query.
int imp[N], K, stk[N], Top, RT, Time, E[N], vis[N], use[N];
// Weighted depth below the virtual-tree root. Must be 64-bit: the answer uses
// K * DEEP[v], and with K and the depth both up to ~1e5 that product overflows
// a 32-bit int before it would be widened.
LL DEEP[N];
LL SIZ[N], ans[N], D[N];

/// Append the directed edge x -> y to the original tree's adjacency list.
inline void add(int x, int y) {
    tp++;
    edge[tp].v = y;
    edge[tp].nex = e[x];
    e[x] = tp;
    return;
}

/// ------------------- tree 1 -------------------------

/// First HLD pass: parent fa[], subtree size siz[], depth d[], heavy child son[].
void DFS_1(int x, int f) {
    fa[x] = f;
    siz[x] = 1;
    d[x] = d[f] + 1;
    forson(x, i) {
        int y = edge[i].v;
        if(y == f) continue;
        DFS_1(y, x);
        siz[x] += siz[y];
        if(siz[y] > siz[son[x]]) {
            son[x] = y;
        }
    }
    return;
}

/// Second HLD pass: chain head top[], DFS position pos[] and its inverse id[].
/// The heavy child is visited first so every heavy chain is contiguous in pos.
void DFS_2(int x, int f) {
    top[x] = f;
    pos[x] = ++num;
    id[num] = x;
    if(son[x]) DFS_2(son[x], f);
    forson(x, i) {
        int y = edge[i].v;
        if(y == fa[x] || y == son[x]) continue;
        DFS_2(y, y);
    }
    return;
}

/// ------------------ seg 1 ----------------------

#define ls (o << 1)
#define rs (o << 1 | 1)

/// Merge children: the boundary between them adds 1 when its two colours differ.
inline void pushup(int o) {
    lc[o] = lc[ls];
    rc[o] = rc[rs];
    sum[o] = sum[ls] + sum[rs] + (rc[ls] != lc[rs]);
    return;
}

/// Propagate a pending "assign colour" tag; a uniformly coloured interval has
/// no differing adjacent pairs, hence sum = 0.
inline void pushdown(int o) {
    if(tag[o] != -1) {
        lc[ls] = rc[ls] = tag[ls] = tag[o];
        lc[rs] = rc[rs] = tag[rs] = tag[o];
        sum[ls] = sum[rs] = 0;
        tag[o] = -1;
    }
    return;
}

#undef ls
#undef rs

/// Build node o over [l, r] from the initial vertex colours val[].
void build(int l, int r, int o) {
    if(l == r) {
        lc[o] = rc[o] = val[id[r]];
        sum[o] = 0;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, o << 1);
    build(mid + 1, r, o << 1 | 1);
    pushup(o);
    return;
}

/// Assign colour v to every position in [L, R].
void change(int L, int R, int v, int l, int r, int o) {
    if(L <= l && r <= R) {
        lc[o] = rc[o] = tag[o] = v;
        sum[o] = 0;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(o);
    if(L <= mid) change(L, R, v, l, mid, o << 1);
    if(mid < R) change(L, R, v, mid + 1, r, o << 1 | 1);
    pushup(o);
    return;
}

/// Colour at position p.
int ask(int p, int l, int r, int o) {
    if(l == r) return lc[o];
    int mid = (l + r) >> 1;
    pushdown(o);
    if(p <= mid) return ask(p, l, mid, o << 1);
    else return ask(p, mid + 1, r, o << 1 | 1);
}

/// Number of adjacent positions with different colours inside [L, R].
int getSum(int L, int R, int l, int r, int o) {
    if(L <= l && r <= R) {
        return sum[o];
    }
    pushdown(o);
    int mid = (l + r) >> 1;
    if(R <= mid) return getSum(L, R, l, mid, o << 1);
    if(mid < L) return getSum(L, R, mid + 1, r, o << 1 | 1);
    // Query straddles mid: also count the boundary pair between the halves.
    return getSum(L, R, l, mid, o << 1) + getSum(L, R, mid + 1, r, o << 1 | 1) + (rc[o << 1] != lc[o << 1 | 1]);
}

/// Lowest common ancestor of x and y via HLD chain jumps.
inline int lca(int x, int y) {
    while(top[x] != top[y]) {
        if(d[top[x]] < d[top[y]])
            y = fa[top[y]];
        else
            x = fa[top[x]];
    }
    return d[x] < d[y] ? x : y;
}

/// Number of edges on the path from x up to z whose endpoints have different
/// colours. z must be an ancestor of x (or x itself).
inline int getLen(int x, int z) {
    int col = ask(pos[x], 1, n, 1);  // colour of the vertex just below the current chain
    int cnt = 0;
    while(top[x] != top[z]) {
        // Edge joining the previously processed chain to this one
        // (on the first iteration col is x's own colour, so this adds 0).
        cnt += (col != ask(pos[x], 1, n, 1));
        cnt += getSum(pos[top[x]], pos[x], 1, n, 1);  // edges inside this chain segment
        col = ask(pos[top[x]], 1, n, 1);
        x = fa[top[x]];
    }
    cnt += (col != ask(pos[x], 1, n, 1));
    cnt += getSum(pos[z], pos[x], 1, n, 1);
    return cnt;
}

/// Recolour every vertex on the path x..y with colour v.
inline void Change(int x, int y, int v) {
    while(top[x] != top[y]) {
        if(d[top[x]] > d[top[y]]) {
            change(pos[top[x]], pos[x], v, 1, n, 1);
            x = fa[top[x]];
        }
        else {
            change(pos[top[y]], pos[y], v, 1, n, 1);
            y = fa[top[y]];
        }
    }
    if(d[x] < d[y]) std::swap(x, y);
    change(pos[y], pos[x], v, 1, n, 1);
    return;
}

/// ------------------- tree 2 ----------------------

/// Add virtual-tree edge parent x -> child y, weighted by the number of
/// differently-coloured edges on the original path y..x.
inline void ADD(int x, int y) {
    TP++;
    EDGE[TP].v = y;
    EDGE[TP].len = getLen(y, x);
    EDGE[TP].nex = E[x];
    E[x] = TP;
    return;
}

/// Order vertices by DFS position.
inline bool cmp(const int &a, const int &b) {
    return pos[a] < pos[b];
}

/// Reset x's per-query state (D and adjacency head) the first time it is seen in this query.
inline void work(int x) {
    if(vis[x] == Time) return;
    vis[x] = Time;
    D[x] = E[x] = 0;
    return;
}

/// Build the virtual tree of imp2[1..K] with the usual monotone stack over DFS
/// order; leaves its root in RT.
inline void build_t() {
    TP = 0;
    memcpy(imp + 1, imp2 + 1, K * sizeof(int));
    std::sort(imp + 1, imp + K + 1, cmp);
    stk[Top = 1] = imp[1];
    work(imp[1]);
    for(int i = 2; i <= K; i++) {
        int x = imp[i], y = lca(x, stk[Top]);
        work(x); work(y);
        while(Top > 1 && d[y] <= d[stk[Top - 1]]) {
            ADD(stk[Top - 1], stk[Top]);
            Top--;
        }
        if(y != stk[Top]) {
            ADD(y, stk[Top]);
            stk[Top] = y;
        }
        stk[++Top] = x;
    }
    while(Top > 1) {
        ADD(stk[Top - 1], stk[Top]);
        Top--;
    }
    RT = stk[Top];
    return;
}

/// DP 1: SIZ[x] = number of query vertices in x's virtual subtree.
void dfs_1(int x) {
    SIZ[x] = (use[x] == Time);
    for(int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        dfs_1(y);
        SIZ[x] += SIZ[y];
    }
    return;
}

/// DP 2 (top-down): D[y] accumulates SIZ[child] * len over the root..y edges,
/// which equals sum over query vertices j of DEEP[lca(y, j)]. DEEP[y] is y's
/// weighted depth. For query vertices the D value is stored in ans[].
void dfs_2(int x) {
    if(use[x] == Time) {
        ans[x] = D[x];
    }
    for(int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        D[y] = D[x] + SIZ[y] * EDGE[i].len;
        DEEP[y] = DEEP[x] + EDGE[i].len;
        dfs_2(y);
    }
    return;
}

/// Build the virtual tree for the current query and run both DP passes.
inline void cal() {
    build_t();
    dfs_1(RT);
    DEEP[RT] = 0;
    dfs_2(RT);
    return;
}

int main() {
    memset(tag, -1, sizeof(tag));
    int q;
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &val[i]);
    }
    for(int i = 1, x, y; i < n; i++) {
        scanf("%d%d", &x, &y);
        add(x, y); add(y, x);
    }
    DFS_1(1, 0);
    DFS_2(1, 1);
    build(1, n, 1);

    for(int i = 1, f, x, y, z; i <= q; i++) {
        scanf("%d%d", &f, &x);
        if(f == 1) {
            // Recolour the path x..y with colour z.
            scanf("%d%d", &y, &z);
            Change(x, y, z);
        }
        else {
            // Query over x given vertices.
            Time++;
            K = x;
            for(int j = 1; j <= K; j++) {
                scanf("%d", &imp2[j]);
                use[imp2[j]] = Time;
            }
            cal();
            LL SUM = 0;
            for(int j = 1; j <= K; j++) {
                SUM += DEEP[imp2[j]];
            }
            // For each query vertex v:
            //   sum_j (dist(v, j) + 1)
            //     = sum_j DEEP[j] + K * DEEP[v] - 2 * sum_j DEEP[lca(v, j)] + K
            // All terms are LL (DEEP is LL), so K * DEEP[v] cannot overflow.
            for(int j = 1; j <= K; j++) {
                int v = imp2[j];
                printf("%lld ", SUM + K * DEEP[v] - 2 * ans[v] + K);
            }
            puts("");
        }
    }
    return 0;
}