syakyo-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Luzhiled/syakyo-library

✔ test/atcoder/abc307_h.test.cpp

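Depends on

- src/cpp-template/header/int-alias.hpp
- src/cpp-template/header/size-alias.hpp
- src/math/convolution/modint-convolution.hpp
- src/math/modular-arithmetic/mod-pow.hpp
- src/math/modular-arithmetic/static-modint.hpp
- src/string/wildcard-pattern-matching.hpp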

Code

// verification-helper: PROBLEM https://atcoder.jp/contests/abc307/tasks/abc307_h
// verification-helper: PROBLEM https://atcoder.jp/contests/abc307/tasks/abc307_ex

#include "src/cpp-template/header/size-alias.hpp"
#include "src/math/modular-arithmetic/static-modint.hpp"
#include "src/string/wildcard-pattern-matching.hpp"

#include <algorithm>
#include <iostream>

namespace luz {

  void main_() {
    const i64 mod = 998244353;

    usize l, w;
    std::string s;

    std::cin >> l >> w >> s;
    s += std::string(w - 1, '.');
    s += s.substr(0, w - 1);

    std::string p;
    std::cin >> p;

    auto wpm = wildcard_pattern_matching(s.begin(), s.end(), p.begin(), p.end(), '_', mod);

    std::cout << std::count(wpm.begin(), wpm.end(), 1) << std::endl;
  }

} // namespace luz

int main() {
  // All work happens in the library-namespaced entry point.
  luz::main_();
  return 0;
}
#line 1 "test/atcoder/abc307_h.test.cpp"
// verification-helper: PROBLEM https://atcoder.jp/contests/abc307/tasks/abc307_h
// verification-helper: PROBLEM https://atcoder.jp/contests/abc307/tasks/abc307_ex

#line 2 "src/cpp-template/header/size-alias.hpp"

#include <cstddef>

namespace luz {

  // Size types used throughout the library.
  using usize = std::size_t;     // unsigned: sizes and indices
  using isize = std::ptrdiff_t;  // signed: index differences

} // namespace luz
#line 2 "src/math/modular-arithmetic/static-modint.hpp"

#line 2 "src/cpp-template/header/int-alias.hpp"

#include <cstdint>

namespace luz {

  // Fixed-width integer shorthands.
  using i32 = std::int32_t;
  using u32 = std::uint32_t;
  using i64 = std::int64_t;
  using u64 = std::uint64_t;

} // namespace luz
#line 4 "src/math/modular-arithmetic/static-modint.hpp"

#include <cassert>
#include <iostream>

namespace luz {

  template < u32 mod >
  class StaticPrimeModInt {
    using modint = StaticPrimeModInt;

    u32 v;

   public:
    StaticPrimeModInt(): v(0) {}

    template < typename T >
    StaticPrimeModInt(T t) {
      i64 x = (i64)(t % (i64)mod);
      if (x < 0) x += mod;
      v = (u32)x;
    }

    u32 val() const {
      return v;
    }

    modint &operator+=(const modint &rhs) {
      v += rhs.v;
      if (v >= mod) v -= mod;
      return *this;
    }

    modint &operator-=(const modint &rhs) {
      v += mod - rhs.v; // <-
      if (v >= mod) v -= mod;
      return *this;
    }

    modint &operator*=(const modint &rhs) {
      v = (u32)(u64(1) * v * rhs.v % mod);
      return *this;
    }

    modint &operator/=(const modint &rhs) {
      *this *= rhs.inverse();
      return *this;
    }

    modint operator+() const {
      return *this;
    }

    modint operator-() const {
      return modint() - *this;
    }

    friend modint operator+(const modint &lhs, const modint &rhs) {
      return modint(lhs) += rhs;
    }

    friend modint operator-(const modint &lhs, const modint &rhs) {
      return modint(lhs) -= rhs;
    }

    friend modint operator*(const modint &lhs, const modint &rhs) {
      return modint(lhs) *= rhs;
    }

    friend modint operator/(const modint &lhs, const modint &rhs) {
      return modint(lhs) /= rhs;
    }

    friend bool operator==(const modint &lhs, const modint &rhs) {
      return lhs.v == rhs.v;
    }

    friend bool operator!=(const modint &lhs, const modint &rhs) {
      return lhs.v != rhs.v;
    }

    modint pow(i64 n) const {
      assert(0 <= n);
      modint x = *this, r = 1;
      while (n) {
        if (n & 1) r *= x;
        x *= x;
        n >>= 1;
      }
      return r;
    }

    modint inverse() const {
      assert(v != 0);
      return pow(mod - 2);
    }

    static constexpr u32 get_mod() {
      return mod;
    }

    friend std::ostream &operator<<(std::ostream &os,
                                    const modint &m) {
      os << m.val();
      return os;
    }
  };

  using modint998244353  = StaticPrimeModInt< 998244353 >;
  using modint1000000007 = StaticPrimeModInt< 1000000007 >;

} // namespace luz
#line 2 "src/string/wildcard-pattern-matching.hpp"

#line 2 "src/math/convolution/modint-convolution.hpp"

#line 2 "src/math/modular-arithmetic/mod-pow.hpp"

#line 4 "src/math/modular-arithmetic/mod-pow.hpp"

namespace luz {

  // Computes b^e modulo `mod` by binary exponentiation.
  //
  // `b` may be any value (negative or >= mod); it is reduced into [0, mod)
  // before use, so the result is always in [0, mod).
  // Preconditions (not checked): e >= 0, mod >= 1, and (mod - 1)^2 must fit
  // in i64 (roughly mod < 3.03e9).
  // Returns 0 when mod == 1, and 1 when e == 0 and mod > 1.
  i64 mod_pow(i64 b, i64 e, i64 mod) {
    if (mod == 1) return 0;

    // Reduce the base so that products stay within i64 and the result is
    // never negative.
    b %= mod;
    if (b < 0) b += mod;

    i64 ans{1};

    while (e) {
      if (e & 1) {
        ans = ans * b % mod;
      }
      b = b * b % mod;
      e /= 2;
    }

    return ans;
  }

}
#line 6 "src/math/convolution/modint-convolution.hpp"

#include <vector>

namespace luz {

  usize bw(u64 x) {
    if (x == 0) return 0;
    return 64 - __builtin_clzll(x);
  }

  void butterfly(std::vector< i64 > &vs, i64 mod) {
    constexpr i64 root = 62;
    usize n = vs.size(), h = bw(n) - 1;

    static std::vector< i64 > rt(2, 1);

    for (static usize k = 2, s = 2; k < n; k *= 2, s++) {
      rt.resize(n);
      i64 z[] = {1, mod_pow(root, mod >> s, mod)};
      for (usize i = k; i < 2 * k; i++) {
        rt[i] = rt[i / 2] * z[i & 1] % mod;
      }
    }

    std::vector< i64 > rev(n);

    for (usize i = 0; i < n; i++) {
      rev[i] = (rev[i / 2] | (i & 1) << h) / 2;
    }

    for (usize i = 0; i < n; i++) {
      if ((i64)i >= rev[i]) continue;
      std::swap(vs[i], vs[rev[i]]);
    }

    for (usize k = 1; k < n; k *= 2) {
      for (usize i = 0; i < n; i += 2 * k) {
        for (usize j = 0; j < k; j++) {
          i64 z = rt[j + k] * vs[i + j + k] % mod;
          i64 &vi = vs[i + j];

          vs[i + j + k] = vi - z + (z > vi ? mod : 0);
          vi += (vi + z >= mod ? z - mod : z);
        }
      }
    }
  }

  std::vector< i64 > modint_convolution(std::vector< i64 > f,
                                        std::vector< i64 > g,
                                        i64 mod) {
    usize n = f.size(), m = g.size();

    if (not n or not m) return {};
    
    usize s = 1 << bw(n + m - 2);
    i64 inv = mod_pow(s, mod - 2, mod);

    f.resize(s);
    g.resize(s);

    butterfly(f, mod);
    butterfly(g, mod);

    std::vector< i64 > res(s);
    for (isize i = 0; (usize)i < s; i++) {
      res[-i & (s - 1)] = f[i] * g[i] % mod * inv % mod;
    }
    butterfly(res, mod);

    res.resize(n + m - 1);
    return res;
  }

}
#line 6 "src/string/wildcard-pattern-matching.hpp"

#line 9 "src/string/wildcard-pattern-matching.hpp"

namespace luz {

  // Wildcard pattern matching via three NTT convolutions.
  //
  // Let n = l1 - f1 (text length) and m = l2 - f2 (pattern length); requires
  // m <= n (asserted). Returns a vector of size n - m + 1 where result[i] == 1
  // iff the pattern matches text[i, i + m). An element equal to `wildcard`, in
  // either sequence, matches anything. An empty pattern matches at every
  // offset.
  //
  // Offset i is accepted when
  //   sum_k [text[i+k], pattern[k] not wildcard] (text[i+k] - pattern[k])^2
  // is 0 modulo `mod`. Elements are compared modulo `mod`, so values that are
  // congruent modulo `mod` are treated as equal.
  //
  // [warning] false positives are possible: a nonzero sum may be a multiple
  //           of `mod`; for random values this happens with probability
  //           about 1/mod per offset.
  // [note]    if necessary, run with several moduli and AND the results.
  // [note]    see butterfly()/modint_convolution() for requirements on `mod`.
  template < class Iter >
  std::vector< i32 > wildcard_pattern_matching(Iter f1, Iter l1,
                                               Iter f2, Iter l2,
                                               const i64 wildcard,
                                               i64 mod) {
    usize n = l1 - f1, m = l2 - f2;
    assert(m <= n);

    // Empty pattern: matches everywhere. Handled up front because the
    // convolutions below would be empty and the index m - 1 would underflow.
    if (m == 0) return std::vector< i32 >(n + 1, 1);

    // Reduce a value into [0, mod) so negative inputs (e.g. signed chars
    // >= 0x80) never feed negative numbers into the NTT.
    auto reduce = [mod](i64 x) {
      x %= mod;
      return x < 0 ? x + mod : x;
    };

    std::vector< i64 > as(n), bs(n), cs(n), ss(m), ts(m), us(m);

    // Text side: a = y * x^2, b = -2 * y * x, c = y  (y = 0 for wildcard).
    for (Iter iter = f1; iter != l1; ++iter) {
      const bool wild = (*iter == wildcard);
      i64 x = wild ? 0 : reduce(*iter);
      i64 y = wild ? 0 : 1;
      usize i = iter - f1;
      as[i]   = y * x * x % mod;
      bs[i]   = y * x * (mod - 2) % mod;
      cs[i]   = y;
    }

    // Pattern side, stored reversed: s = y, t = y * x, u = y * x^2.
    for (Iter iter = f2; iter != l2; ++iter) {
      const bool wild = (*iter == wildcard);
      i64 x = wild ? 0 : reduce(*iter);
      i64 y = wild ? 0 : 1;
      usize i = l2 - iter - 1;
      ss[i]   = y;
      ts[i]   = y * x;
      us[i]   = y * x * x % mod;
    }

    auto f = modint_convolution(as, ss, mod);
    auto g = modint_convolution(bs, ts, mod);
    auto h = modint_convolution(cs, us, mod);

    std::vector< i32 > result(n - m + 1);
    for (usize i = 0; i < result.size(); i++) {
      // Coefficient i + m - 1 pairs text[i + k] with pattern[k] (reversed).
      usize j = i + m - 1;
      i64 x((f[j] + g[j] + h[j]) % mod);
      if (x == 0) result[i] = 1;
    }

    return result;
  }

} // namespace luz
#line 7 "test/atcoder/abc307_h.test.cpp"

#include <algorithm>
#line 10 "test/atcoder/abc307_h.test.cpp"

namespace luz {

  void main_() {
    const i64 mod = 998244353;

    usize l, w;
    std::string s;

    std::cin >> l >> w >> s;
    s += std::string(w - 1, '.');
    s += s.substr(0, w - 1);

    std::string p;
    std::cin >> p;

    auto wpm = wildcard_pattern_matching(s.begin(), s.end(), p.begin(), p.end(), '_', mod);

    std::cout << std::count(wpm.begin(), wpm.end(), 1) << std::endl;
  }

} // namespace luz

int main() {
  // All work happens in the library-namespaced entry point.
  luz::main_();
  return 0;
}
Back to top page