2 条题解

  • 0
    @ 2024-11-21 2:44:05

    小提示:

    1. 给定点 $x,y$ 求两点之间简单路径的长度公式:记 $t$ 为 $x,y$ 的最近公共祖先,$dep_i$ 表示 $i$ 点的深度,则 $length=dep_x+dep_y-2\times dep_t$
    // Distance (number of edges) between x and y: the two depths minus twice
    // the depth of their lowest common ancestor.
    int getLength(int x, int y) {
        const int anc = lca(x, y);
        return dep[x] + dep[y] - 2 * dep[anc];
    }
    
    2. 给定点 $x,y$ 求两点之间简单路径的点权和公式:记 $t$ 为 $x,y$ 的最近公共祖先,$w_i$ 为根节点到 $i$ 点的点权和(需要预处理),$fa_t$ 表示 $t$ 的父节点,则 $sum=w_x+w_y-w_t-w_{fa_t}$
    // Sum of point weights on the path x..y using root prefix sums w:
    // the LCA t is counted once, so remove w[t] and w[fa[t]].
    int getSum(int x, int y) {
        const int t = lca(x, y);
        auto through = w[x] + w[y];
        return through - w[t] - w[fa[t]];
    }
    
    3. 给定点 $x,y$ 求两点之间简单路径的边权和公式:记 $t$ 为 $x,y$ 的最近公共祖先,$w_i$ 为根节点到 $i$ 点的边权和(需要预处理),则 $sum=w_x+w_y-2\times w_t$
    // Sum of edge weights on the path x..y using root prefix sums w:
    // the part above the LCA t is counted twice, so remove 2 * w[t].
    int getSum(int x, int y) {
        const int t = lca(x, y);
        auto both = w[x] + w[y];
        return both - 2 * w[t];
    }
    

    以下使用 树上前缀和+哈希

    同样是使用哈希预处理出 $n$ 个排列($1\sim k$,$k=1,\dots,n$)的哈希值,用公式 2 求两点之间的点权和,之后和对应长度排列的哈希值比较是否相同即可。

    #include <bits/stdc++.h>
    
    // Integer shorthands used throughout the solution.
    using u32 = unsigned int;
    using i64 = long long int;
    using u64 = unsigned long long int;
    using u128 = unsigned __int128;

    // Polynomial hash base: main() maps a node value k to base^k.
    constexpr int base = 13331;
    
    // Binary exponentiation: returns res * a^b (res defaults to 1).
    template<class T>
    constexpr T power(T a, u64 b, T res = 1) {
        while (b > 0) {
            if (b % 2 == 1) {
                res *= a;
            }
            a *= a;
            b /= 2;
        }
        return res;
    }
    
    // (a * b) mod P for a 32-bit modulus, using a 64-bit intermediate product.
    template<u32 P>
    constexpr u32 mulMod(u32 a, u32 b) {
        const u64 product = static_cast<u64>(a) * b;
        return static_cast<u32>(product % P);
    }
    
    // (a * b) mod P for a 64-bit modulus without 128-bit division.
    // The quotient is estimated in long double and biased down by 0.5, so the
    // wrapping difference a*b - q*P lands in a small range that a final % P
    // fixes. NOTE(review): this assumes long double has a 64-bit mantissa.
    template<u64 P>
    constexpr u64 mulMod(u64 a, u64 b) {
        const long double approx = static_cast<long double>(a) * b / P - 0.5L;
        const u64 q = static_cast<u64>(approx);
        return (a * b - q * P) % P;
    }
    
    // x modulo m shifted into [0, m) when the C++ remainder is negative (m > 0).
    constexpr i64 safeMod(i64 x, i64 m) {
        const i64 r = x % m;
        return r < 0 ? r + m : r;
    }
    
    // Extended Euclidean algorithm (same scheme as the AtCoder Library's inv_gcd).
    // Returns {g, x} with g = gcd(a, b) and x * a ≡ g (mod b); per ACL,
    // 0 <= x < b / g. When g == 1, x is the modular inverse of a.
    // b must be positive, since safeMod() reduces a modulo b.
    constexpr std::pair<i64, i64> invGcd(i64 a, i64 b) {
        a = safeMod(a, b);
        if (a == 0) {
            // gcd(0, b) = b and 0 * a ≡ b (mod b).
            return {b, 0};
        }
        
        // Loop invariants: s ≡ m0 * a (mod b) and t ≡ m1 * a (mod b).
        i64 s = b, t = a;
        i64 m0 = 0, m1 = 1;
    
        while (t) {
            i64 u = s / t;
            s -= t * u;
            m0 -= m1 * u;
            
            std::swap(s, t);
            std::swap(m0, m1);
        }
        
        // s now holds the gcd; move the coefficient into [0, b / s).
        if (m0 < 0) {
            m0 += b / s;
        }
        
        return {s, m0};
    }
    
    /**
     * Residue modulo the compile-time constant P, stored in the unsigned type U.
     *
     * Invariant: 0 <= x < P. operator+= / operator-= rely on unsigned
     * wrap-around and assume P < 2^(bits of U - 1). inv() uses Fermat's little
     * theorem and is only correct when P is prime.
     */
    template<std::unsigned_integral U, U P>
    struct ModIntBase {
    public:
        constexpr ModIntBase() : x(0) {}
        template<std::unsigned_integral T>
        constexpr ModIntBase(T x_) : x(x_ % mod()) {}
        // Signed input is reduced through its magnitude in the matching unsigned
        // type, so a value wider than U (e.g. an i64 read by operator>> into a
        // u32 residue) is never truncated before reduction, and the most
        // negative value of T does not overflow.
        template<std::signed_integral T>
        constexpr ModIntBase(T x_) {
            using UT = std::make_unsigned_t<T>;
            const bool negative = x_ < 0;
            const UT magnitude = negative ? UT(UT(0) - UT(x_)) : UT(x_);
            const U r = U(magnitude % mod());
            x = (negative && r != 0) ? U(mod() - r) : r;
        }
        
        constexpr static U mod() {
            return P;
        }
        
        constexpr U val() const {
            return x;
        }
        
        constexpr ModIntBase operator-() const {
            ModIntBase res;
            res.x = (x == 0 ? 0 : mod() - x);
            return res;
        }
        
        // Multiplicative inverse via x^(P-2); valid only for prime P and x != 0.
        constexpr ModIntBase inv() const {
            return power(*this, mod() - 2);
        }
        
        constexpr ModIntBase &operator*=(const ModIntBase &rhs) & {
            x = mulMod<mod()>(x, rhs.val());
            return *this;
        }
        constexpr ModIntBase &operator+=(const ModIntBase &rhs) & {
            x += rhs.val();
            if (x >= mod()) {
                x -= mod();
            }
            return *this;
        }
        constexpr ModIntBase &operator-=(const ModIntBase &rhs) & {
            // An underflow wraps to a huge unsigned value; adding P wraps back.
            x -= rhs.val();
            if (x >= mod()) {
                x += mod();
            }
            return *this;
        }
        constexpr ModIntBase &operator/=(const ModIntBase &rhs) & {
            return *this *= rhs.inv();
        }
        
        friend constexpr ModIntBase operator*(ModIntBase lhs, const ModIntBase &rhs) {
            lhs *= rhs;
            return lhs;
        }
        friend constexpr ModIntBase operator+(ModIntBase lhs, const ModIntBase &rhs) {
            lhs += rhs;
            return lhs;
        }
        friend constexpr ModIntBase operator-(ModIntBase lhs, const ModIntBase &rhs) {
            lhs -= rhs;
            return lhs;
        }
        friend constexpr ModIntBase operator/(ModIntBase lhs, const ModIntBase &rhs) {
            lhs /= rhs;
            return lhs;
        }
        
        friend constexpr std::istream &operator>>(std::istream &is, ModIntBase &a) {
            i64 i;
            is >> i;
            a = i;
            return is;
        }
        friend constexpr std::ostream &operator<<(std::ostream &os, const ModIntBase &a) {
            return os << a.val();
        }
        
        friend constexpr std::strong_ordering operator<=>(ModIntBase lhs, ModIntBase rhs) {
            return lhs.val() <=> rhs.val();
        }
        
        // A user-provided operator<=> does not implicitly declare operator==,
        // so equality (and the rewritten !=) must be supplied explicitly.
        friend constexpr bool operator==(ModIntBase lhs, ModIntBase rhs) {
            return lhs.val() == rhs.val();
        }
        
    private:
        U x;
    };
    
    // Residue types with a compile-time modulus: 32-bit and 64-bit storage.
    template<u32 P> using ModInt = ModIntBase<u32, P>;
    template<u64 P> using ModInt64 = ModIntBase<u64, P>;
    
    /**
     * Barrett reduction for a run-time 32-bit modulus m.
     *
     * im = floor((2^64 - 1) / m) + 1 is a fixed-point approximation of 2^64 / m
     * that lets mul() estimate the quotient with a multiply and shift instead
     * of a division.
     * NOTE(review): m_ == 1 makes im wrap around to 0, which breaks mul();
     * callers presumably always pass m >= 2 — confirm.
     */
    struct Barrett {
    public:
        // Non-explicit: DynModInt assigns/initialises its reducer straight from
        // an integer (`bt = m`, `bt = 998244353`).
        Barrett(u32 m_) : m(m_), im((u64)(-1) / m_ + 1) {}
    
        constexpr u32 mod() const {
            return m;
        }
    
        // Returns (a * b) mod m.
        constexpr u32 mul(u32 a, u32 b) const {
            u64 z = a;
            z *= b;
            
            // Estimated quotient floor(z * im / 2^64); per the standard Barrett
            // analysis (as in the AtCoder Library) it is the true quotient or
            // one too large.
            u64 x = u64((u128(z) * im) >> 64);
            
            // If x was one too large, z - x * m is "negative" and wraps; after
            // truncation to 32 bits it is >= m (this relies on m <= 2^31), and
            // adding m wraps it back into [0, m).
            u32 v = u32(z - x * m);
            if (m <= v) {
                v += m;
            }
            return v;
        }
    
    private:
        u32 m;  // modulus
        u64 im; // ceil-style reciprocal of m scaled by 2^64
    };
    
    /**
     * Residue modulo a run-time modulus shared by every DynModInt<Id> with the
     * same Id. The modulus lives in a static Barrett reducer; change it with
     * setMod(). Invariant: 0 <= x < mod().
     */
    template<u32 Id>
    struct DynModInt {
    public:
        constexpr DynModInt() : x(0) {}
        template<std::unsigned_integral T>
        constexpr DynModInt(T x_) : x(x_ % mod()) {}
        // Signed input is reduced via its magnitude in the matching unsigned
        // type. Narrowing to `int` first would silently truncate i64 values
        // (e.g. from operator>>) before the modulo.
        template<std::signed_integral T>
        constexpr DynModInt(T x_) {
            using UT = std::make_unsigned_t<T>;
            const bool negative = x_ < 0;
            const UT magnitude = negative ? UT(UT(0) - UT(x_)) : UT(x_);
            const u32 r = u32(magnitude % mod());
            x = (negative && r != 0) ? mod() - r : r;
        }
        
        constexpr static void setMod(u32 m) {
            bt = m;
        }
        
        static u32 mod() {
            return bt.mod();
        }
        
        constexpr u32 val() const {
            return x;
        }
        
        constexpr DynModInt operator-() const {
            DynModInt res;
            res.x = (x == 0 ? 0 : mod() - x);
            return res;
        }
        
        // Inverse via extended Euclid; asserts that x is coprime to mod().
        constexpr DynModInt inv() const {
            auto v = invGcd(x, mod());
            assert(v.first == 1);
            return v.second;
        }
        
        constexpr DynModInt &operator*=(const DynModInt &rhs) & {
            x = bt.mul(x, rhs.val());
            return *this;
        }
        constexpr DynModInt &operator+=(const DynModInt &rhs) & {
            x += rhs.val();
            if (x >= mod()) {
                x -= mod();
            }
            return *this;
        }
        constexpr DynModInt &operator-=(const DynModInt &rhs) & {
            // An underflow wraps to a huge unsigned value; adding mod() wraps back.
            x -= rhs.val();
            if (x >= mod()) {
                x += mod();
            }
            return *this;
        }
        constexpr DynModInt &operator/=(const DynModInt &rhs) & {
            return *this *= rhs.inv();
        }
        
        friend constexpr DynModInt operator*(DynModInt lhs, const DynModInt &rhs) {
            lhs *= rhs;
            return lhs;
        }
        friend constexpr DynModInt operator+(DynModInt lhs, const DynModInt &rhs) {
            lhs += rhs;
            return lhs;
        }
        friend constexpr DynModInt operator-(DynModInt lhs, const DynModInt &rhs) {
            lhs -= rhs;
            return lhs;
        }
        friend constexpr DynModInt operator/(DynModInt lhs, const DynModInt &rhs) {
            lhs /= rhs;
            return lhs;
        }
        
        friend constexpr std::istream &operator>>(std::istream &is, DynModInt &a) {
            i64 i;
            is >> i;
            a = i;
            return is;
        }
        friend constexpr std::ostream &operator<<(std::ostream &os, const DynModInt &a) {
            return os << a.val();
        }
        
        friend constexpr std::strong_ordering operator<=>(DynModInt lhs, DynModInt rhs) {
            return lhs.val() <=> rhs.val();
        }
    
        friend constexpr bool operator==(DynModInt lhs, DynModInt rhs) {
            return lhs.val() == rhs.val();
        }
        
    private:
        u32 x;
        static Barrett bt;
    };
    
    // Every DynModInt<Id> starts with modulus 998244353 until setMod() is called.
    template<u32 Id>
    Barrett DynModInt<Id>::bt(998244353);

    using Z = DynModInt<0>;
    
    /**
     * Rooted tree with binary-lifting LCA and root-to-node prefix sums of node
     * weights (despite the name, no heavy-light decomposition is performed).
     *
     * Nodes are 1-indexed. Index 0 acts as the parent of the root: its depth is
     * 0 and every ancestor chain ends there, so w[0] must be zero (T{}) for
     * getSum() to be correct when the LCA is the root. The lifting table has
     * LOG levels, supporting depth differences up to 2^LOG - 1.
     *
     * Usage: construct, add() every edge, call work() once, then query.
     */
    template<class T>
    struct HLD {
        static constexpr int LOG = 17; // binary-lifting levels

        int n;
        std::vector<std::vector<int>> adj;   // adjacency lists
        std::vector<int> siz, dep;           // subtree sizes, depths (root = 1)
        std::vector<int> fa;                 // parents (fa[root] == 0)
        std::vector<std::array<int, LOG>> f; // f[v][j]: 2^j-th ancestor of v, 0 past the root
        std::vector<T> w;                    // weights; root-to-node sums after work()

        HLD(int n_) {
            n = n_;
            adj.resize(n + 1);
            siz.resize(n + 1);
            dep.resize(n + 1);
            fa.resize(n + 1);
            f.resize(n + 1);
            w.resize(n + 1); // zero weights unless init() supplies them
        }
        HLD(const std::vector<T> &w): HLD(int(w.size())) {
            init(w);
        }

        void init(const std::vector<T> &w) {
            this->w = w;
        }

        void add(int u, int v) {
            adj[u].push_back(v);
            adj[v].push_back(u);
        }

        // Fills dep, siz, the lifting table f and the root-prefix weights for
        // the tree hanging from `root` (fa[root] must already hold its parent,
        // 0 for the real root). Iterative, so path-shaped trees cannot exhaust
        // the call stack; member functions (not [&]-capturing std::function
        // members) keep the object safely copyable.
        void dfs1(int root) {
            std::vector<int> order;
            std::vector<int> todo{root};
            while (!todo.empty()) {
                const int cur = todo.back();
                todo.pop_back();
                order.push_back(cur);
                dep[cur] = dep[fa[cur]] + 1;
                siz[cur] = 1;
                f[cur][0] = fa[cur];
                for (int j = 1; j < LOG; j++) {
                    f[cur][j] = f[f[cur][j - 1]][j - 1];
                }
                for (int v : adj[cur]) {
                    if (v == fa[cur]) {
                        continue;
                    }
                    fa[v] = cur;
                    w[v] += w[cur]; // w[cur] is already final: its parent was processed first
                    todo.push_back(v);
                }
            }
            // Every node appears after its parent in `order`, so a reverse sweep
            // accumulates subtree sizes bottom-up.
            for (auto it = order.rbegin(); it != order.rend(); ++it) {
                if (*it != root) {
                    siz[fa[*it]] += siz[*it];
                }
            }
        }

        // Second pass hook; no per-node data is computed here.
        void dfs2(int) {}

        // Lowest common ancestor of x and y via binary lifting.
        int lca(int x, int y) const {
            if (dep[x] < dep[y]) {
                std::swap(x, y);
            }
            // Lift x to y's depth; node 0 has depth 0, so jumps past the root are rejected.
            for (int j = LOG - 1; j >= 0; j--) {
                if (dep[f[x][j]] >= dep[y]) {
                    x = f[x][j];
                }
            }
            if (x == y) {
                return x;
            }
            // Lift both to just below their LCA.
            for (int j = LOG - 1; j >= 0; j--) {
                if (f[x][j] != f[y][j]) {
                    x = f[x][j];
                    y = f[y][j];
                }
            }
            return f[x][0];
        }

        // Number of edges on the x..y path.
        int calc(int x, int y) {
            return dep[x] + dep[y] - 2 * dep[lca(x, y)];
        }

        // Sum of node weights on the x..y path (both endpoints included).
        T getSum(int x, int y) {
            int t = lca(x, y);
            return w[x] + w[y] - w[t] - w[fa[t]];
        }

        void work(int root = 1) {
            dfs1(root);
            dfs2(root);
        }
    };
    
    
    int main() {
    	std::ios::sync_with_stdio(false);
    	std::cin.tie(nullptr);
    
    	int n, q;
    	std::cin >> n >> q;
    	std::vector<int> a(n + 1);
    	for (int i = 1; i <= n; i++) {
    		std::cin >> a[i];
    	}
    
    	std::vector<Z> p(n + 1), pre(n + 1);
    	p[0] = 1ULL;
    	for (int i = 1; i <= n; i++) {
    		p[i] = p[i - 1] * base;
    		pre[i] = pre[i - 1] + p[i];
    	}
    
        std::vector<Z> w(n + 1);
        for (int i = 1; i <= n; i++) {
            w[i] = p[a[i]];
        }
    
        HLD<Z> hld(w);
    
    	for (int i = 1; i < n; i++) {
    		int u, v;
    		std::cin >> u >> v;
            hld.add(u, v);
    	}
    
        hld.work();
    
    	while (q--) {
    		int u, v;
    		std::cin >> u >> v;
    		int len = hld.calc(u, v) + 1;
            Z ans = hld.getSum(u, v);
            std::cout << (ans == pre[len] ? "Yes" : "No") << "\n";
    	}
    	return 0;
    }
    
    • 0
      @ 2024-11-19 18:05:18

      树链剖分 + 多重集合哈希 + 树状数组

      用哈希预处理出 $n$ 个排列($1\sim k$,$k=1,\dots,n$)的哈希值,再用树链剖分和树状数组求出 $u,v$ 两点简单路径上的点权哈希和,比较一下即可。

      #include <bits/stdc++.h>
      
      // Integer shorthands.
      using u32 = unsigned int;
      using u64 = unsigned long long int;
      using i64 = long long int;
      
      // Binary exponentiation: a^b with a T-valued accumulator starting at 1.
      template<typename T>
      constexpr T power(T a, u64 b) {
          T acc {1};
          while (b != 0) {
              if (b & 1) {
                  acc *= a;
              }
              b >>= 1;
              a *= a;
          }
          return acc;
      }
      
      // (a * b) mod P for a 32-bit modulus via a 64-bit product.
      template<u32 P>
      constexpr u32 mulMod(u32 a, u32 b) {
          const u64 product = static_cast<u64>(a) * b;
          return static_cast<u32>(product % P);
      }
      
      // (a * b) mod P for a 64-bit modulus: estimate the quotient in long
      // double (biased down by 0.5), take the wrapping remainder, then % P.
      // NOTE(review): assumes long double has a 64-bit mantissa.
      template<u64 P>
      constexpr u64 mulMod(u64 a, u64 b) {
          const long double approx = static_cast<long double>(a) * b / P - 0.5L;
          const u64 q = static_cast<u64>(approx);
          return (a * b - q * P) % P;
      }
      
      /**
       * Residue modulo the compile-time constant P, stored in unsigned type U.
       *
       * Invariant: 0 <= x < P. norm() detects a wrapped-around (negative)
       * intermediate through the top bit of U, so P must be below
       * 2^(bits of U - 1). inv() uses Fermat's little theorem and is only
       * valid when P is prime.
       */
      template<typename U, U P>
      // requires std::unsigned_integral<U>
      struct ModIntBase {
      public:
          constexpr ModIntBase() : x {0} {}
          
          // Accepts any integer type. The value is reduced without converting P
          // to T, so this compiles even when P does not fit in T (e.g. int with
          // P = 10^18 + 9, as used by power()'s `T res {1}`); negative inputs
          // map into [0, P).
          template<typename T>
          // requires std::integral<T>
          constexpr ModIntBase(T x_) : x {reduce(x_)} {}
          
          // Reduces an arbitrary integer into [0, P) via its unsigned magnitude.
          template<typename T>
          constexpr static U reduce(T v) {
              if constexpr (std::is_signed_v<T>) {
                  using UT = std::make_unsigned_t<T>;
                  const UT mag = v < 0 ? UT(UT(0) - UT(v)) : UT(v);
                  const U r = U(mag % P);
                  return (v < 0 && r != 0) ? U(P - r) : r;
              } else {
                  return U(v % P);
              }
          }
          
          // Maps a value in (-P, 2P) (negatives as wrapped unsigned) into [0, P).
          constexpr static U norm(U x) {
              if ((x >> (8 * sizeof(U) - 1) & 1) == 1) {
                  x += P;
              }
              if (x >= P) {
                  x -= P;
              }
              return x;
          }
          
          constexpr U val() const {
              return x;
          }
          
          constexpr ModIntBase operator-() const {
              ModIntBase res;
              res.x = norm(P - x);
              return res;
          }
          
          constexpr ModIntBase inv() const {
              return power(*this, P - 2);
          }
          
          constexpr ModIntBase &operator*=(const ModIntBase &rhs) & {
              x = mulMod<P>(x, rhs.val());
              return *this;
          }
          
          constexpr ModIntBase &operator+=(const ModIntBase &rhs) & {
              x = norm(x + rhs.x);
              return *this;
          }
          
          constexpr ModIntBase &operator-=(const ModIntBase &rhs) & {
              x = norm(x - rhs.x);
              return *this;
          }
          
          constexpr ModIntBase &operator/=(const ModIntBase &rhs) & {
              return *this *= rhs.inv();
          }
          
          friend constexpr ModIntBase operator*(ModIntBase lhs, const ModIntBase &rhs) {
              lhs *= rhs;
              return lhs;
          }
          
          friend constexpr ModIntBase operator+(ModIntBase lhs, const ModIntBase &rhs) {
              lhs += rhs;
              return lhs;
          }
          
          friend constexpr ModIntBase operator-(ModIntBase lhs, const ModIntBase &rhs) {
              lhs -= rhs;
              return lhs;
          }
          
          friend constexpr ModIntBase operator/(ModIntBase lhs, const ModIntBase &rhs) {
              lhs /= rhs;
              return lhs;
          }
          
          friend constexpr std::ostream &operator<<(std::ostream &os, const ModIntBase &a) {
              return os << a.val();
          }
          
          friend constexpr bool operator==(ModIntBase lhs, ModIntBase rhs) {
              return lhs.val() == rhs.val();
          }
          
          friend constexpr bool operator!=(ModIntBase lhs, ModIntBase rhs) {
              return lhs.val() != rhs.val();
          }
          
          friend constexpr bool operator<(ModIntBase lhs, ModIntBase rhs) {
              return lhs.val() < rhs.val();
          }
          
      private:
          U x;
      };
      
      // Residue types with a compile-time modulus: 32-bit and 64-bit storage.
      template<u32 P> using ModInt = ModIntBase<u32, P>;
      template<u64 P> using ModInt64 = ModIntBase<u64, P>;
      
      // Hash modulus 10^18 + 9, polynomial base, and the residue type built on them.
      constexpr u64 P = 1000000000000000009ULL;
      constexpr u64 Base = 1145141ULL;
      using Z = ModInt64<P>;
      
      /**
       * Single polynomial string hash over the modular type T with base `Base`;
       * get_sub(l, r) hashes the 1-indexed substring [l, r] in O(1).
       */
      template <class T>
      struct Hash1 {
          const int N;
          std::vector<T> p, S; // p[i] = Base^i, S[i] = hash of the first i characters
          Hash1(int len) : N(len), p(len + 1), S(len + 1) {}
          Hash1(std::string s) : Hash1(int(s.size())) {
              p[0] = 1ULL;
              for (int i = 1; i <= N; ++i) {
                  p[i] = p[i - 1] * Base;
                  S[i] = S[i - 1] * Base + 1ULL * s[i - 1];
              }
          }

          T get_sub(int l, int r) {
              const int len = r - l + 1;
              return S[r] - S[l - 1] * p[len];
          }
      };
      
      constexpr int mod1 = 998244353;  // first modulus for Hash2
      constexpr int mod2 = 1000000007; // second modulus for Hash2 (10^9 + 7)
      
      /**
       * Double polynomial string hash (bases P1 / P2 modulo mod1 / mod2);
       * get_sub(l, r) returns the hash pair of the 1-indexed substring [l, r].
       */
      struct Hash2 {
          const int N, P1 = 13331, P2 = 1145141;
          std::vector<int> p1, p2, S1, S2;
          Hash2(int _) : N(_), p1(_ + 1), p2(_ + 1), S1(_ + 1), S2(_ + 1) {}
          Hash2(std::string s) : Hash2(int(s.size())) {
              p1[0] = p2[0] = 1;
              for (int i = 1; i <= N; ++i) {
                  p1[i] = 1LL * p1[i - 1] * P1 % mod1;
                  p2[i] = 1LL * p2[i - 1] * P2 % mod2;
              }
              for (int i = 1; i <= N; ++i) {
                  // Read each byte as unsigned: a negative `char` (non-ASCII on
                  // signed-char platforms) would make prefix hashes negative, and
                  // equal substrings could then get different get_sub() values.
                  const int c = static_cast<unsigned char>(s[i - 1]);
                  S1[i] = (1LL * S1[i - 1] * P1 % mod1 + c) % mod1;
                  S2[i] = (1LL * S2[i - 1] * P2 % mod2 + c) % mod2;
              }
          }
      
          std::pair<int, int> get_sub(int l, int r) {
              return {(S1[r] - 1LL * S1[l - 1] * p1[r - l + 1] % mod1 + mod1) % mod1, 
                  (S2[r] - 1LL * S2[l - 1] * p2[r - l + 1] % mod2 + mod2) % mod2
              };
          }
      };
      
      /**
       * Fenwick (binary indexed) tree over positions 1..N with point add and
       * prefix / range sum queries. T must be constructible from 0ULL.
       */
      template <class T>
      struct Fenwick {
          const int N;
          std::vector<T> f;
          Fenwick(int _) : N(_), f(_ + 10) {}
      
          // Adds v at position x. Positions <= 0 are ignored: with x == 0 the
          // update step (i & -i) is 0 and the loop would never terminate.
          void add(int x, T v) {
              if (x <= 0) {
                  return;
              }
              for (int i = x; i <= N; i += (i bitand -i)) {
                  f[i] += v;
              }
          }
      
          // Sum of positions 1..x. x <= 0 gives 0; x > N is clamped to N, so
          // the result is the total rather than 0.
          T query(int x) {
              if (x <= 0) {
                  return T(0ULL);
              }
              if (x > N) {
                  x = N;
              }
      
              T res = 0ULL;
              for (int i = x; i; i -= (i bitand -i)) {
                  res += f[i];
              }
              return res;
          }
      
          // Sum of positions l..r (0 for an empty range).
          T rangesum(int l, int r) {
              if (l > r) {
                  return T(0ULL);
              }
      
              return query(r) - query(l - 1);
          }
      };
      
      // Global Fenwick tree indexed by DFS position (capacity 100000 —
      // presumably the problem's bound on n; TODO confirm).
      Fenwick<Z> fen(100000);

      // Adjacency list of the input tree (1-indexed), filled by main().
      std::vector<std::vector<int>> e;
      
      /**
       * Heavy-light decomposition of the tree stored in the global adjacency
       * list `e`, rooted at rt.
       *
       * Nodes are 1-indexed; node 0 is the sentinel parent of the root (depth 0).
       * After construction, [L[u], R[u]] is the DFS-position range of u's
       * subtree, re[i] is the node at position i, top[u] is the highest node of
       * u's heavy chain, and each heavy chain occupies consecutive positions.
       * rangesum() reads path sums from the global Fenwick tree `fen`, which
       * the caller fills by DFS position.
       */
      struct HLD {
          int n, DFN = 0, rt;
          std::vector<int> L, R, dep, fa, top, sz, son, re;
          HLD(int _, int RT = 1) {
              init(_, RT);
          }
      
          void init(int _, int RT) {
              this->n = _;
              this->rt = RT;
              this->DFN = 0; // positions restart from 1 on every (re)initialisation
              L.resize(n + 1);
              R.resize(n + 1);
              fa.resize(n + 1);
              re.resize(n + 1);
              sz.resize(n + 1);
              son.resize(n + 1);
              top.resize(n + 1);
              dep.resize(n + 1);
              dfs1(rt, 0);
              dfs2(rt, rt);
          }
      
          // Computes fa, dep, sz and the heavy child son (-1 for leaves) for the
          // subtree of u whose parent is F. Iterative so that path-shaped trees
          // cannot overflow the call stack. Ties between equally large children
          // go to the first one in e[u], matching a recursive traversal.
          void dfs1(int u, int F) {
              std::vector<int> order;
              std::vector<int> todo{u};
              fa[u] = F;
              while (!todo.empty()) {
                  const int x = todo.back();
                  todo.pop_back();
                  order.push_back(x);
                  dep[x] = dep[fa[x]] + 1;
                  for (int v : e[x]) {
                      if (v != fa[x]) {
                          fa[v] = x;
                          todo.push_back(v);
                      }
                  }
              }
              // Children come after their parent in `order`, so a reverse sweep
              // sees every child's final size before the parent.
              for (auto it = order.rbegin(); it != order.rend(); ++it) {
                  const int x = *it;
                  sz[x] = 1;
                  son[x] = -1;
                  for (int v : e[x]) {
                      if (v != fa[x]) {
                          sz[x] += sz[v];
                          if (son[x] == -1 || sz[v] > sz[son[x]]) {
                              son[x] = v;
                          }
                      }
                  }
              }
          }
      
          // Assigns DFS positions starting at u (whose chain top is tp): the
          // heavy child is visited right after its parent, then the light
          // children, each starting a new chain, in e[u] order. Iterative; R[u]
          // follows from sz because every subtree occupies consecutive positions.
          void dfs2(int u, int tp) {
              std::vector<std::pair<int, int>> todo{{u, tp}};
              while (!todo.empty()) {
                  const auto [x, chainTop] = todo.back();
                  todo.pop_back();
                  L[x] = ++DFN;
                  re[DFN] = x;
                  top[x] = chainTop;
                  R[x] = L[x] + sz[x] - 1;
                  if (son[x] == -1) {
                      continue;
                  }
                  // Push light children in reverse so they pop in e[x] order,
                  // then the heavy child last so it is visited next.
                  for (auto it = e[x].rbegin(); it != e[x].rend(); ++it) {
                      if (*it != son[x] && *it != fa[x]) {
                          todo.emplace_back(*it, *it);
                      }
                  }
                  todo.emplace_back(son[x], chainTop);
              }
          }
      
          // Lowest common ancestor: climb chain by chain, always from the chain
          // whose top is deeper.
          int lca(int u, int v) {
              while (top[u] != top[v]) {
                  if (dep[top[u]] > dep[top[v]]) {
                      u = fa[top[u]];
                  } else {
                      v = fa[top[v]];
                  }
              }
      
              return dep[u] < dep[v] ? u : v;
          }
      
          // Ancestor of u that is k levels above it (u itself for k == 0), or 0
          // when it does not exist (k < 0 or k >= dep[u]).
          int Kth_ancestor(int u, int k) {
              if (k < 0 || k >= dep[u]) {
                  return 0;
              }
              int fu = top[u];
              while (dep[u] - dep[fu] < k && u != rt) {
                  k -= dep[u] - dep[fu] + 1;
                  u = fa[fu];
                  fu = top[u];
              }
      
              return re[L[u] - k];
          }
      
          // Sum of fen over every node on the u..v path (both ends included).
          // z is unused; it is kept so existing call sites still compile.
          Z rangesum(int u, int v, int z = 1) {
              int fu = top[u], fv = top[v];
              Z res = 0ULL;
              while (fu != fv) {
                  if (dep[fu] >= dep[fv]) {
                      // lseg.modify(1, L[fu], L[u], z);
                      res += fen.rangesum(L[fu], L[u]);
                      u = fa[fu];
                  } else {
                      // lseg.modify(1, L[fv], L[v], z);
                      res += fen.rangesum(L[fv], L[v]);
                      v = fa[fv];
                  }
                  fu = top[u];
                  fv = top[v];
              }
      
              if (dep[u] > dep[v]) {
                  // lseg.modify(1, L[v], L[u], z);
                  res += fen.rangesum(L[v], L[u]);
              } else {
                  // lseg.modify(1, L[u], L[v], z);
                  res += fen.rangesum(L[u], L[v]);
              }
              return res;
          }
      };
      
      int main() {
          std::ios::sync_with_stdio(false);
          std::cin.tie(nullptr);
      
          int n, q;
          std::cin >> n >> q;
      
          std::vector<int> a(n);
          for (int i = 0; i < n; ++i) {
              std::cin >> a[i];
          }
      
          e.resize(n + 1);
          for(int i = 1; i < n; ++i) {
              int u, v;
              std::cin >> u >> v;
              e[u].emplace_back(v);
              e[v].emplace_back(u);
          }
      
          std::vector<Z> p(n + 1), pref(n + 1);
          p[0] = 1ULL;
          for (int i = 1; i <= n; ++i) {
              p[i] = p[i - 1] * Base;
              pref[i] = p[i] + pref[i - 1];
          }
      
          HLD hld(n);
      
          for (int i = 1; i <= n; ++i) {
              int t = hld.re[i];
              fen.add(i, p[a[t - 1]]);
          }
      
          while (q--) {
              int u, v;
              std::cin >> u >> v;
      
              int Lca = hld.lca(u, v), len = hld.dep[u] + hld.dep[v] - 2 * hld.dep[Lca] + 1;
              Z res = hld.rangesum(u, v);
      
              std::cout << (res == pref[len] ? "Yes" : "No") << "\n";
          }
      
          return 0;
      }
      
      • 1

      信息

      ID
      144
      时间
      3000ms
      内存
      256MiB
      难度
      10
      标签
      递交数
      6
      已通过
      3
      上传者