2 条题解

  • 0
    @ 2024-11-21 2:44:05

    小提示:

    1. 给定点 $x,y$,求两点之间简单路径的长度公式:记 $t$ 为 $x,y$ 的最近公共祖先,$dep_i$ 表示 $i$ 点的深度,则 $length=dep_x+dep_y-2\times dep_t$
    // Length of the simple path between x and y, from the depths of both
    // endpoints and of their lowest common ancestor.
    int getLength(int x, int y) {
        const int t = lca(x, y);
        return dep[x] + dep[y] - 2 * dep[t];
    }
    
    2. 给定点 $x,y$,求两点之间简单路径的点权和公式:记 $t$ 为 $x,y$ 的最近公共祖先,$w_i$ 表示根节点到 $i$ 点的点权和(需要预处理),$fa_t$ 表示 $t$ 的父节点,则 $sum=w_x+w_y-w_t-w_{fa_t}$
    // Sum of the vertex weights on the path between x and y.
    // w[] holds root-to-vertex prefix sums, so the part above the LCA is removed
    // by subtracting the prefix of the LCA and the prefix of the LCA's parent.
    int getSum(int x, int y) {
        const int top = lca(x, y);
        const auto above = fa[top];
        return w[x] + w[y] - w[top] - w[above];
    }
    
    3. 给定点 $x,y$,求两点之间简单路径的边权和公式:记 $t$ 为 $x,y$ 的最近公共祖先,$w_i$ 表示根节点到 $i$ 点的边权和(需要预处理),则 $sum=w_x+w_y-2\times w_t$
    // Sum of the edge weights on the path between x and y.
    // w[] holds root-to-vertex prefix sums of edge weights, so the prefix of the
    // LCA is counted twice and removed twice.
    int getSum(int x, int y) {
        const int top = lca(x, y);
        const auto shared = 2 * w[top];
        return w[x] + w[y] - shared;
    }
    

    以下使用 树上前缀和+哈希

    同样是用哈希预处理出 $1\sim n$ 这 $n$ 个长度的排列的哈希值,用公式 2 求两点之间路径上的点权(哈希)和,再与对应长度的排列哈希值比较是否相同即可。

    #include <bits/stdc++.h>
    
    // Shorthand integer aliases used by the templates below.
    // u128 relies on the compiler's unsigned __int128 extension type.
    using u32 = unsigned;
    using i64 = long long;
    using u64 = unsigned long long;
    using u128 = unsigned __int128;
    
    // Base whose powers serve as the per-value hashes built in main().
    const int base = 13331;
    
    /// Binary exponentiation: returns res * a^b using O(log b) multiplications.
    ///
    /// @param a    base; only needs operator*=.
    /// @param b    non-negative exponent.
    /// @param res  starting accumulator (defaults to 1, so T must be constructible from 1).
    ///
    /// The base is squared only while exponent bits remain. The previous loop
    /// squared once more after the last bit, which is wasted work and can overflow
    /// a signed T even though the result itself is representable.
    template<class T>
    constexpr T power(T a, u64 b, T res = 1) {
        while (b != 0) {
            if (b & 1) {
                res *= a;
            }
            b >>= 1;
            if (b != 0) {
                a *= a;
            }
        }
        return res;
    }
    
    /// (a * b) mod P for a 32-bit modulus; the product is formed in 64 bits so it
    /// cannot overflow before the reduction.
    template<u32 P>
    constexpr u32 mulMod(u32 a, u32 b) {
        const u64 wide = u64(a) * b;
        return wide % P;
    }
    
    /// (a * b) mod P for a 64-bit modulus.
    ///
    /// The product is formed exactly in 128 bits and then reduced. The previous
    /// version estimated the quotient in long double, which is only reliable when
    /// long double has a 64-bit mantissa and P is well below 2^64; this form is
    /// exact for every a, b and P on any target that provides unsigned __int128.
    template<u64 P>
    constexpr u64 mulMod(u64 a, u64 b) {
        return static_cast<u64>(static_cast<u128>(a) * b % P);
    }
    
    /// Remainder of x modulo m, shifted up by m when the built-in % yields a
    /// negative value.
    constexpr i64 safeMod(i64 x, i64 m) {
        const i64 r = x % m;
        return r < 0 ? r + m : r;
    }
    
    /// Extended Euclid on (a mod b, b).
    ///
    /// Returns {g, c} where g is the last non-zero remainder of the Euclidean
    /// algorithm and c is the matching coefficient of a modulo b, made
    /// non-negative by adding b / g. When a is a multiple of b the result is {b, 0}.
    constexpr std::pair<i64, i64> invGcd(i64 a, i64 b) {
        a = safeMod(a, b);
        if (a == 0) {
            return {b, 0};
        }

        // Loop invariant: r0 == c0 * a (mod b) and r1 == c1 * a (mod b).
        i64 r0 = b, r1 = a;
        i64 c0 = 0, c1 = 1;

        while (r1 != 0) {
            const i64 q = r0 / r1;
            const i64 rNext = r0 - r1 * q;
            const i64 cNext = c0 - c1 * q;

            r0 = r1;
            r1 = rNext;
            c0 = c1;
            c1 = cNext;
        }

        if (c0 < 0) {
            c0 += b / r0;
        }

        return {r0, c0};
    }
    
    /// Integer modulo the compile-time constant P, stored in U as a value in [0, P).
    ///
    /// inv() raises to the power P - 2, so division is only meaningful when P is
    /// prime. operator+= adds before reducing, so 2 * P must fit in U.
    template<std::unsigned_integral U, U P>
    struct ModIntBase {
    public:
        constexpr ModIntBase() : x(0) {}
        template<std::unsigned_integral T>
        constexpr ModIntBase(T x_) : x(x_ % mod()) {}
        /// Reduces a signed value into [0, P).
        ///
        /// The magnitude is reduced in an unsigned type at least as wide as both T
        /// and U, so wide inputs (e.g. the i64 read by operator>>) are no longer
        /// truncated to U's width before the modulo, and moduli above the signed
        /// maximum of U are handled too.
        template<std::signed_integral T>
        constexpr ModIntBase(T x_) {
            using W = std::common_type_t<std::make_unsigned_t<T>, U>;
            const W m = mod();
            if (x_ >= 0) {
                x = U(W(x_) % m);
            } else {
                // W(0) - W(x_) is |x_| computed without signed overflow.
                const W r = W(W(0) - W(x_)) % m;
                x = U(r == 0 ? W(0) : W(m - r));
            }
        }
        
        constexpr static U mod() {
            return P;
        }
        
        constexpr U val() const {
            return x;
        }
        
        constexpr ModIntBase operator-() const {
            ModIntBase res;
            res.x = (x == 0 ? 0 : mod() - x);
            return res;
        }
        
        /// Multiplicative inverse as x^(P-2); valid only for prime P and x != 0.
        constexpr ModIntBase inv() const {
            return power(*this, mod() - 2);
        }
        
        constexpr ModIntBase &operator*=(const ModIntBase &rhs) & {
            x = mulMod<mod()>(x, rhs.val());
            return *this;
        }
        constexpr ModIntBase &operator+=(const ModIntBase &rhs) & {
            x += rhs.val();
            if (x >= mod()) {
                x -= mod();
            }
            return *this;
        }
        constexpr ModIntBase &operator-=(const ModIntBase &rhs) & {
            // An unsigned wrap-around below zero shows up as x >= mod(); adding
            // mod() wraps it back into range.
            x -= rhs.val();
            if (x >= mod()) {
                x += mod();
            }
            return *this;
        }
        constexpr ModIntBase &operator/=(const ModIntBase &rhs) & {
            return *this *= rhs.inv();
        }
        
        friend constexpr ModIntBase operator*(ModIntBase lhs, const ModIntBase &rhs) {
            lhs *= rhs;
            return lhs;
        }
        friend constexpr ModIntBase operator+(ModIntBase lhs, const ModIntBase &rhs) {
            lhs += rhs;
            return lhs;
        }
        friend constexpr ModIntBase operator-(ModIntBase lhs, const ModIntBase &rhs) {
            lhs -= rhs;
            return lhs;
        }
        friend constexpr ModIntBase operator/(ModIntBase lhs, const ModIntBase &rhs) {
            lhs /= rhs;
            return lhs;
        }
        
        /// Reads a signed 64-bit integer and reduces it modulo P.
        friend constexpr std::istream &operator>>(std::istream &is, ModIntBase &a) {
            i64 i;
            is >> i;
            a = i;
            return is;
        }
        friend constexpr std::ostream &operator<<(std::ostream &os, const ModIntBase &a) {
            return os << a.val();
        }
        
        friend constexpr std::strong_ordering operator<=>(ModIntBase lhs, ModIntBase rhs) {
            return lhs.val() <=> rhs.val();
        }
        
        // A user-provided operator<=> does not generate operator==, so it is
        // spelled out here.
        friend constexpr bool operator==(ModIntBase lhs, ModIntBase rhs) {
            return lhs.val() == rhs.val();
        }
        
    private:
        U x;
    };
    
    // Convenience aliases: ModIntBase with 32-bit and with 64-bit storage.
    template<u32 P>
    using ModInt = ModIntBase<u32, P>;
    template<u64 P>
    using ModInt64 = ModIntBase<u64, P>;
    
    /// Barrett-style reduction for a modulus chosen at run time.
    ///
    /// Stores m together with im = (2^64 - 1) / m + 1, which lets mul() replace
    /// the division by a 128-bit multiply and a shift.
    struct Barrett {
    public:
        Barrett(u32 m_) : m(m_), im((u64)(-1) / m_ + 1) {}
    
        /// The modulus this reducer was built for.
        constexpr u32 mod() const {
            return m;
        }
    
        /// Product of a and b reduced modulo m.
        constexpr u32 mul(u32 a, u32 b) const {
            const u64 prod = u64(a) * b;
            // High 64 bits of prod * im: the quotient estimate.
            const u64 quot = u64((u128(prod) * im) >> 64);
    
            u32 rem = u32(prod - quot * m);
            // If the subtraction wrapped around, rem shows up as >= m and adding
            // m wraps it back.
            if (rem >= m) {
                rem += m;
            }
            return rem;
        }
    
    private:
        u32 m;
        u64 im;
    };
    
    /// Integer modulo a run-time modulus shared by all values with the same Id.
    ///
    /// The modulus lives in the static Barrett reducer bt and is changed with
    /// setMod(); values created before a setMod() call are not re-reduced.
    /// inv() uses the extended Euclid helper and asserts that the value is
    /// coprime to the modulus.
    template<u32 Id>
    struct DynModInt {
    public:
        constexpr DynModInt() : x(0) {}
        template<std::unsigned_integral T>
        constexpr DynModInt(T x_) : x(x_ % mod()) {}
        /// Reduces a signed value into [0, mod()).
        ///
        /// The magnitude is reduced in an unsigned type at least as wide as T, so
        /// 64-bit inputs (e.g. the i64 read by operator>>) are no longer truncated
        /// to int before the modulo is taken.
        template<std::signed_integral T>
        constexpr DynModInt(T x_) {
            using W = std::common_type_t<std::make_unsigned_t<T>, u32>;
            const W m = mod();
            if (x_ >= 0) {
                x = u32(W(x_) % m);
            } else {
                // W(0) - W(x_) is |x_| computed without signed overflow.
                const W r = W(W(0) - W(x_)) % m;
                x = u32(r == 0 ? W(0) : W(m - r));
            }
        }
        
        /// Replaces the modulus used by every DynModInt<Id>.
        constexpr static void setMod(u32 m) {
            bt = m;
        }
        
        static u32 mod() {
            return bt.mod();
        }
        
        constexpr u32 val() const {
            return x;
        }
        
        constexpr DynModInt operator-() const {
            DynModInt res;
            res.x = (x == 0 ? 0 : mod() - x);
            return res;
        }
        
        /// Multiplicative inverse; asserts that the value is coprime to mod().
        constexpr DynModInt inv() const {
            auto v = invGcd(x, mod());
            assert(v.first == 1);
            return v.second;
        }
        
        constexpr DynModInt &operator*=(const DynModInt &rhs) & {
            x = bt.mul(x, rhs.val());
            return *this;
        }
        constexpr DynModInt &operator+=(const DynModInt &rhs) & {
            x += rhs.val();
            if (x >= mod()) {
                x -= mod();
            }
            return *this;
        }
        constexpr DynModInt &operator-=(const DynModInt &rhs) & {
            // An unsigned wrap-around below zero shows up as x >= mod(); adding
            // mod() wraps it back into range.
            x -= rhs.val();
            if (x >= mod()) {
                x += mod();
            }
            return *this;
        }
        constexpr DynModInt &operator/=(const DynModInt &rhs) & {
            return *this *= rhs.inv();
        }
        
        friend constexpr DynModInt operator*(DynModInt lhs, const DynModInt &rhs) {
            lhs *= rhs;
            return lhs;
        }
        friend constexpr DynModInt operator+(DynModInt lhs, const DynModInt &rhs) {
            lhs += rhs;
            return lhs;
        }
        friend constexpr DynModInt operator-(DynModInt lhs, const DynModInt &rhs) {
            lhs -= rhs;
            return lhs;
        }
        friend constexpr DynModInt operator/(DynModInt lhs, const DynModInt &rhs) {
            lhs /= rhs;
            return lhs;
        }
        
        /// Reads a signed 64-bit integer and reduces it modulo mod().
        friend constexpr std::istream &operator>>(std::istream &is, DynModInt &a) {
            i64 i;
            is >> i;
            a = i;
            return is;
        }
        friend constexpr std::ostream &operator<<(std::ostream &os, const DynModInt &a) {
            return os << a.val();
        }
        
        friend constexpr std::strong_ordering operator<=>(DynModInt lhs, DynModInt rhs) {
            return lhs.val() <=> rhs.val();
        }
    
        friend constexpr bool operator==(DynModInt lhs, DynModInt rhs) {
            return lhs.val() == rhs.val();
        }
        
    private:
        u32 x;
        static Barrett bt;
    };
    
    // Every DynModInt<Id> starts with modulus 998244353 until setMod() replaces it.
    template<u32 Id>
    Barrett DynModInt<Id>::bt = 998244353;
    
    // Number type used by main() for the hashes.
    using Z = DynModInt<0>;
    
    /// Rooted-tree helper: binary-lifting LCA, distances and path sums of vertex weights.
    ///
    /// Vertices are numbered from 1; index 0 is the "no parent" sentinel with
    /// depth 0 and must not be used as a real vertex. Typical use: construct,
    /// add() the edges of the tree, call work(root) once, then query lca(),
    /// calc() and getSum(). work() turns w into prefix sums in place, so calling
    /// it a second time on the same object accumulates the weights again.
    ///
    /// T is the weight type; it needs +, -, += and value-initialisation.
    ///
    /// The traversal and queries are plain member functions, so copies of the
    /// object are independent of the original.
    template<class T>
    struct HLD {
        /// Number of binary-lifting levels. Jumps of up to 2^LOG - 1 edges are
        /// representable, so the depth of the tree must not exceed 2^17 = 131072.
        static constexpr int LOG = 17;
    
        int n;
        std::vector<std::vector<int>> adj;   // adjacency lists
        std::vector<int> siz, dep;           // subtree size; depth (the root has depth 1)
        std::vector<int> fa;                 // parent of each vertex (0 for the root)
        std::vector<std::array<int, LOG>> f; // f[v][j]: 2^j-th ancestor of v, 0 if none
        std::vector<T> w;                    // vertex weight; root-to-vertex sum after work()
    
        /// Creates an edgeless structure for vertices 1..n_ with value-initialised weights.
        HLD(int n_) {
            n = n_;
            adj.resize(n + 1);
            siz.resize(n + 1);
            dep.resize(n + 1);
            fa.resize(n + 1);
            f.resize(n + 1);
            // Size the weights too, so work() is safe even if init() is never called.
            w.resize(n + 1);
        }
        /// Creates the structure from a weight vector indexed by vertex number;
        /// n is set to w.size().
        HLD(const std::vector<T> &w) : HLD(int(w.size())) {
            init(w);
        }
    
        /// Replaces the vertex weights; missing trailing entries are value-initialised.
        void init(const std::vector<T> &w) {
            this->w = w;
            if (this->w.size() < adj.size()) {
                this->w.resize(adj.size());
            }
        }
    
        /// Adds the undirected edge u - v.
        void add(int u, int v) {
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
    
        /// Fills dep, siz, fa, f and the prefix sums in w for the subtree hanging
        /// below cur. fa[cur] must already be set (0 for the root).
        ///
        /// Done iteratively in breadth-first order: a parent is always finalised
        /// before its children, and a long path cannot exhaust the call stack.
        void dfs1(int cur) {
            std::vector<int> order{cur};
            for (std::size_t i = 0; i < order.size(); i++) {
                const int u = order[i];
                dep[u] = dep[fa[u]] + 1;
                siz[u] = 1;
                f[u][0] = fa[u];
                for (int j = 1; j < LOG; j++) {
                    f[u][j] = f[f[u][j - 1]][j - 1];
                }
                for (int v : adj[u]) {
                    if (v == fa[u]) {
                        continue;
                    }
                    fa[v] = u;
                    w[v] += w[u];
                    order.push_back(v);
                }
            }
            // Children appear after their parent in `order`, so a reverse sweep
            // accumulates subtree sizes bottom-up.
            for (std::size_t i = order.size() - 1; i > 0; i--) {
                siz[fa[order[i]]] += siz[order[i]];
            }
        }
    
        /// Kept so existing dfs2(v) calls still compile: the original second pass
        /// walked the tree without computing anything, so this does nothing.
        void dfs2(int) {}
    
        /// Lowest common ancestor of x and y (both must be vertices processed by work()).
        int lca(int x, int y) const {
            if (dep[x] < dep[y]) {
                std::swap(x, y);
            }
            // Lift the deeper vertex to the depth of the other one.
            for (int j = LOG - 1; j >= 0; j--) {
                if (dep[f[x][j]] >= dep[y]) {
                    x = f[x][j];
                }
            }
            if (x == y) {
                return x;
            }
            // Lift both while their ancestors differ; they end just below the LCA.
            for (int j = LOG - 1; j >= 0; j--) {
                if (f[x][j] != f[y][j]) {
                    x = f[x][j];
                    y = f[y][j];
                }
            }
            return f[x][0];
        }
    
        /// Number of edges on the path between x and y.
        int calc(int x, int y) const {
            return dep[x] + dep[y] - 2 * dep[lca(x, y)];
        }
    
        /// Sum of the vertex weights on the path between x and y (endpoints included).
        T getSum(int x, int y) const {
            const int t = lca(x, y);
            // The root has no parent: contribute nothing instead of reading the
            // unused slot w[0], which the caller may have filled with anything.
            const T above = fa[t] != 0 ? w[fa[t]] : T{};
            return w[x] + w[y] - w[t] - above;
        }
    
        /// Roots the tree at `root` and precomputes everything the queries need.
        void work(int root = 1) {
            fa[root] = 0;
            dfs1(root);
        }
    };
    
    
    int main() {
    	std::ios::sync_with_stdio(false);
    	std::cin.tie(nullptr);
    
    	int n, q;
    	std::cin >> n >> q;
    	std::vector<int> a(n + 1);
    	for (int i = 1; i <= n; i++) {
    		std::cin >> a[i];
    	}
    
    	std::vector<Z> p(n + 1), pre(n + 1);
    	p[0] = 1ULL;
    	for (int i = 1; i <= n; i++) {
    		p[i] = p[i - 1] * base;
    		pre[i] = pre[i - 1] + p[i];
    	}
    
        std::vector<Z> w(n + 1);
        for (int i = 1; i <= n; i++) {
            w[i] = p[a[i]];
        }
    
        HLD<Z> hld(w);
    
    	for (int i = 1; i < n; i++) {
    		int u, v;
    		std::cin >> u >> v;
            hld.add(u, v);
    	}
    
        hld.work();
    
    	while (q--) {
    		int u, v;
    		std::cin >> u >> v;
    		int len = hld.calc(u, v) + 1;
            Z ans = hld.getSum(u, v);
            std::cout << (ans == pre[len] ? "Yes" : "No") << "\n";
    	}
    	return 0;
    }
    

    信息

    ID
    144
    时间
    3000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    6
    已通过
    3
    上传者