2 条题解
-
0
小提示:
- 给定点 $x, y$,求两点之间简单路径的长度公式:记 $t$ 为 $x, y$ 的最近公共祖先,$dep_i$ 表示 $i$ 点的深度,则 $\mathrm{len}(x, y) = dep_x + dep_y - 2 \cdot dep_t$。

  ```cpp
  // 查询两点间的距离
  int getLength(int x, int y) { return dep[x] + dep[y] - 2 * dep[lca(x, y)]; }
  ```
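- 给定点 $x, y$,求两点之间简单路径的点权和公式:记 $t$ 为 $x, y$ 的最近公共祖先,$w_i$ 为根节点到 $i$ 点的点权和(需要预处理),$fa_t$ 表示 $t$ 的父节点(根节点的父节点记为 $0$,并令 $w_0 = 0$),则 $\mathrm{sum}(x, y) = w_x + w_y - w_t - w_{fa_t}$。($w_x + w_y$ 把根到 $t$ 的部分算了两次,减去 $w_t$ 与 $w_{fa_t}$ 后,$t$ 恰好只被计入一次。)

  ```cpp
  // 查询两点路径的点权值和
  int getSum(int x, int y) { int t = lca(x, y); return w[x] + w[y] - w[t] - w[fa[t]]; }
  ```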
- 给定点 $x, y$,求两点之间简单路径的边权和公式:记 $t$ 为 $x, y$ 的最近公共祖先,$w_i$ 为根节点到 $i$ 点的边权和(需要预处理),则 $\mathrm{sum}(x, y) = w_x + w_y - 2 \cdot w_t$。

  ```cpp
  // 查询两点路径的边权值和
  int getSum(int x, int y) { int t = lca(x, y); return w[x] + w[y] - 2 * w[t]; }
  ```
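以下使用树上前缀和 + 哈希的做法。
同样是使用哈希预处理出 $n$ 个排列的哈希值:把值 $v$ 映射为 $base^v$,则排列 $1 \sim len$ 的哈希值为 $\sum_{k=1}^{len} base^k$($len = 1, 2, \dots, n$)。以 $base^{a_i}$ 作为点 $i$ 的点权,用公式 2(点权和)求出两点路径上的点权和,再与长度为 $len$(路径上的点数)的排列的哈希值比较是否相同即可。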
// Tree path "is a permutation" check with root-prefix sums and a multiset hash.
//
// Input: n q, then a[1..n], then n - 1 tree edges (1-indexed), then q queries (u, v).
// For each query, prints "Yes" when the hash of the values a[] on the simple path u..v
// equals the hash of {1, ..., len} (len = number of nodes on the path), otherwise "No".
// This is a single fixed-base hash, so collisions are possible in principle.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>
#include <vector>

using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;
using u128 = unsigned __int128;

// Base of the "sum of base^value" hash used in main().
const int base = 13331;

// Returns x mod m in [0, m) for m > 0, also for negative x.
constexpr i64 safeMod(i64 x, i64 m) {
    x %= m;
    if (x < 0) {
        x += m;
    }
    return x;
}

// Extended Euclid: returns {g, r} with g = gcd(a, b) and a * r == g (mod b), 0 <= r < b / g.
constexpr std::pair<i64, i64> invGcd(i64 a, i64 b) {
    a = safeMod(a, b);
    if (a == 0) {
        return {b, 0};
    }
    i64 s = b, t = a;
    i64 m0 = 0, m1 = 1;
    while (t) {
        i64 u = s / t;
        s -= t * u;
        m0 -= m1 * u;
        // Manual swaps: std::swap is only constexpr since C++20.
        i64 tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }
    if (m0 < 0) {
        m0 += b / s;
    }
    return {s, m0};
}

// Barrett reduction: (a * b) mod m for a runtime modulus m without a division per product.
struct Barrett {
public:
    // Requires m_ >= 1.
    explicit Barrett(u32 m_) : m(m_), im(u64(-1) / m_ + 1) {}

    constexpr u32 mod() const { return m; }

    // Returns (a * b) mod m; expects a, b < m.
    constexpr u32 mul(u32 a, u32 b) const {
        u64 z = a;
        z *= b;
        // x is floor(z / m) or one larger, so z - x * m may fall below zero by less than m.
        u64 x = u64((u128(z) * im) >> 64);
        u64 y = x * m;
        return u32(z - y + (z < y ? m : 0));
    }

private:
    u32 m;
    u64 im;
};

// Modular integer with a runtime modulus shared by all values with the same Id
// (998244353 unless setMod is called). operator+= adds two residues in u32, so the
// modulus is assumed to stay below 2^31.
template<u32 Id>
struct DynModInt {
public:
    constexpr DynModInt() : x(0) {}

    // From any unsigned integer.
    template<class T, std::enable_if_t<std::is_integral_v<T> && std::is_unsigned_v<T>, int> = 0>
    DynModInt(T x_) : x(u32(x_ % mod())) {}

    // From any signed integer. Reduces in 64 bits: narrowing to int first would
    // silently truncate i64 inputs (e.g. values read by operator>>).
    template<class T, std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T>, int> = 0>
    DynModInt(T x_) : x(0) {
        i64 v = i64(x_) % i64(mod());
        if (v < 0) {
            v += mod();
        }
        x = u32(v);
    }

    static void setMod(u32 m) { bt = Barrett(m); }
    static u32 mod() { return bt.mod(); }
    constexpr u32 val() const { return x; }

    DynModInt operator-() const {
        DynModInt res;
        res.x = (x == 0 ? 0 : mod() - x);
        return res;
    }

    // Modular inverse via extended Euclid; asserts that the value is invertible.
    DynModInt inv() const {
        auto v = invGcd(x, mod());
        assert(v.first == 1);
        return v.second;
    }

    DynModInt &operator*=(const DynModInt &rhs) & {
        x = bt.mul(x, rhs.val());
        return *this;
    }
    DynModInt &operator+=(const DynModInt &rhs) & {
        x += rhs.val();
        if (x >= mod()) {
            x -= mod();
        }
        return *this;
    }
    DynModInt &operator-=(const DynModInt &rhs) & {
        x -= rhs.val();
        if (x >= mod()) {  // unsigned wrap-around means the result went negative
            x += mod();
        }
        return *this;
    }
    DynModInt &operator/=(const DynModInt &rhs) & { return *this *= rhs.inv(); }

    friend DynModInt operator*(DynModInt lhs, const DynModInt &rhs) {
        lhs *= rhs;
        return lhs;
    }
    friend DynModInt operator+(DynModInt lhs, const DynModInt &rhs) {
        lhs += rhs;
        return lhs;
    }
    friend DynModInt operator-(DynModInt lhs, const DynModInt &rhs) {
        lhs -= rhs;
        return lhs;
    }
    friend DynModInt operator/(DynModInt lhs, const DynModInt &rhs) {
        lhs /= rhs;
        return lhs;
    }
    friend std::istream &operator>>(std::istream &is, DynModInt &a) {
        i64 i;
        is >> i;
        a = i;
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const DynModInt &a) {
        return os << a.val();
    }
    friend bool operator==(DynModInt lhs, DynModInt rhs) { return lhs.val() == rhs.val(); }
    friend bool operator!=(DynModInt lhs, DynModInt rhs) { return !(lhs == rhs); }
    friend bool operator<(DynModInt lhs, DynModInt rhs) { return lhs.val() < rhs.val(); }

private:
    u32 x;
    static Barrett bt;
};

template<u32 Id>
Barrett DynModInt<Id>::bt{998244353u};

using Z = DynModInt<0>;

// Rooted tree with binary-lifting LCA and root-prefix sums of node weights.
// (Despite the name, no heavy-light decomposition is performed.)
// Nodes are 1-indexed; index 0 is the sentinel "parent of the root" with depth 0,
// and getSum() assumes w[0] holds a zero value T().
template<class T>
struct HLD {
    int n;                              // arrays have n + 1 slots (indices 0..n)
    int LOG;                            // lifting levels; 2^LOG >= n covers any depth gap
    std::vector<std::vector<int>> adj;  // adjacency lists
    std::vector<int> siz, dep;          // subtree sizes, node depths (root has depth 1)
    std::vector<int> fa;                // parent array, fa[root] = 0
    std::vector<std::vector<int>> up;   // up[j][v] = 2^j-th ancestor of v (0 above the root)
    std::vector<T> w;                   // node weights; after work(): sum on the root..v path

    explicit HLD(int n_) : n(n_), LOG(1) {
        // Fixed-size tables (the old 17 levels) break once a depth gap exceeds 2^LOG - 1.
        while ((1 << LOG) < n) {
            ++LOG;
        }
        adj.resize(n + 1);
        siz.resize(n + 1);
        dep.resize(n + 1);
        fa.resize(n + 1);
        up.assign(LOG, std::vector<int>(n + 1));
        w.resize(n + 1);
    }

    explicit HLD(const std::vector<T> &w_) : HLD(int(w_.size())) { init(w_); }

    // Replaces the node weights (padded with T() if shorter than n + 1).
    void init(const std::vector<T> &w_) {
        w = w_;
        if (int(w.size()) < n + 1) {
            w.resize(n + 1);
        }
    }

    void add(int u, int v) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Roots the tree at `root` and fills fa, dep, siz, up and the prefix sums in w.
    // Iterative BFS, so path-like trees cannot overflow the call stack.
    // Call once: a second call would accumulate the prefix sums again.
    void work(int root = 1) {
        std::vector<int> order;
        order.reserve(n + 1);
        fa[root] = 0;
        dep[root] = 1;
        order.push_back(root);
        for (std::size_t i = 0; i < order.size(); ++i) {
            const int cur = order[i];
            siz[cur] = 1;
            // A node is dequeued after all its ancestors, so their rows are complete.
            up[0][cur] = fa[cur];
            for (int j = 1; j < LOG; ++j) {
                up[j][cur] = up[j - 1][up[j - 1][cur]];
            }
            for (int v : adj[cur]) {
                if (v == fa[cur]) {
                    continue;
                }
                fa[v] = cur;
                dep[v] = dep[cur] + 1;
                w[v] += w[cur];
                order.push_back(v);
            }
        }
        // Reverse BFS order handles children before parents; skip the root at index 0.
        for (std::size_t i = order.size(); i-- > 1;) {
            siz[fa[order[i]]] += siz[order[i]];
        }
    }

    // Lowest common ancestor of x and y; requires work().
    int lca(int x, int y) const {
        if (dep[x] < dep[y]) {
            std::swap(x, y);
        }
        for (int j = LOG - 1; j >= 0; --j) {
            if (dep[up[j][x]] >= dep[y]) {
                x = up[j][x];
            }
        }
        if (x == y) {
            return x;
        }
        for (int j = LOG - 1; j >= 0; --j) {
            if (up[j][x] != up[j][y]) {
                x = up[j][x];
                y = up[j][y];
            }
        }
        return up[0][x];
    }

    // Number of edges on the path x..y.
    int calc(int x, int y) const { return dep[x] + dep[y] - 2 * dep[lca(x, y)]; }

    // Sum of node weights on the path x..y, both ends included: root..t is counted twice
    // by w[x] + w[y], so subtract w[t] and w[fa[t]] to keep t exactly once.
    T getSum(int x, int y) const {
        const int t = lca(x, y);
        return w[x] + w[y] - w[t] - w[fa[t]];
    }
};

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, q;
    std::cin >> n >> q;
    std::vector<int> a(n + 1);
    for (int i = 1; i <= n; i++) {
        std::cin >> a[i];
    }

    // p[k] = base^k; pre[len] = base^1 + ... + base^len hashes the value set {1, ..., len}.
    std::vector<Z> p(n + 1), pre(n + 1);
    p[0] = 1ULL;
    for (int i = 1; i <= n; i++) {
        p[i] = p[i - 1] * base;
        pre[i] = pre[i - 1] + p[i];
    }

    // Node i carries base^{a[i]}. A value outside [1, n] can never be part of a permutation
    // of 1..len (len <= n); map it to base^0, which is not a term of any pre[len], instead
    // of reading p[] out of bounds. w[0] stays zero, as HLD::getSum requires.
    std::vector<Z> w(n + 1);
    for (int i = 1; i <= n; i++) {
        w[i] = (1 <= a[i] && a[i] <= n) ? p[a[i]] : p[0];
    }

    HLD<Z> hld(w);
    for (int i = 1; i < n; i++) {
        int u, v;
        std::cin >> u >> v;
        hld.add(u, v);
    }
    hld.work();

    while (q--) {
        int u, v;
        std::cin >> u >> v;
        // The path has len nodes; compare its hash with the hash of {1, ..., len}.
        int len = hld.calc(u, v) + 1;
        Z ans = hld.getSum(u, v);
        std::cout << (ans == pre[len] ? "Yes" : "No") << "\n";
    }
    return 0;
}
信息
- ID
- 144
- 时间
- 3000ms
- 内存
- 256MiB
- 难度
- 10
- 标签
- 递交数
- 6
- 已通过
- 3
- 上传者