noshi91のメモ

データ構造のある風景

modint 構造体を使ってみませんか? (C++)

2019/04/01 実装例の訂正などを行いました。
2019/04/13 実行時に法が決まる問題についての説明を追加しました。

modint is 何

競技プログラミングの問題で、答えを 1000000007 (あるいは他の素数) で割った余りを求める問題は頻出です。これらの問題ではアルゴリズムの過程で繰り返し mod を取りますが、modint は普通の整数型などと同じ感覚で扱うだけで自動的に mod を取ってくれるというものです。

使用例

単純ですが、以下の問題を解いてみることにします。

a, b, c, d が与えられます。a * b + c - d mod M (M は定数) を出力してください。

modint なし

long long a, b, c, d;
cin >> a >> b >> c >> d;
cout << ((a * b % M + c) % M - d + M) % M;

modint なし、関数使用

long long a, b, c, d;
cin >> a >> b >> c >> d;
cout << sub(add(mul(a, b, M), c, M), d, M);

modint 使用

modint<M> a, b, c, d;
cin >> a.a >> b.a >> c.a >> d.a;
cout << (a * b + c - d).a;


エクサウィザーズ2019を modint を使用して回答しました。

D Modulo Operations
Submission #4794708 - ExaWizards 2019

E Black or White
Submission #4794711 - ExaWizards 2019

利点

  • 可読性の向上
    % M などにコードのロジック部分が覆い隠されなくなり、バグの低減などが見込めます。
  • 速度の向上
    加減算は条件分岐を用いることで剰余計算を行わない高速化が可能です。modint なら実装の内部でそのような高速化を自由に掛けることが出来ます。コード本体に直接高速化を入れると、可読性の著しい低下などが懸念されます。
  • mod 取り忘れがなくなる
    mod を取り忘れて値が正しい領域を外れることがなくなります。
  • 分数がデバッグしやすくなる
    using mint = modint などとしておけば、double に書き換えることでそのまま分数計算などのデバッグが可能です。(場合によっては少しだけ入出力を書き換えるかもしれませんが)
  • 他のライブラリに適用しやすい
    例えば累乗を計算するとき、modpow(a, b, M) といった関数を用意しなくとも、template を用いて一般化した pow 関数に modint をそのまま適用することで使用可能になります。他にも行列ライブラリなどに適用可能です。


欠点

  • ライブラリがコード全体に占める割合が増える
    コンパクトにまとめたい場合、使用しない operator などは冗長になります。
  • アルゴリズムの性質に依存した高速化が行えない
    64bit を超えないギリギリまで剰余を取らずにおく、0 にならないことが分かっているので符号反転で条件分岐を行わない、等の切り詰めた高速化は不可能になります。


実装例

自由にご使用ください

#include <cstdint>

template <std::uint_fast64_t Modulus> class modint {
  using u64 = std::uint_fast64_t;

public:
  u64 a;

  constexpr modint(const u64 x = 0) noexcept : a(x % Modulus) {}
  constexpr u64 &value() noexcept { return a; }
  constexpr const u64 &value() const noexcept { return a; }
  constexpr modint operator+(const modint rhs) const noexcept {
    return modint(*this) += rhs;
  }
  constexpr modint operator-(const modint rhs) const noexcept {
    return modint(*this) -= rhs;
  }
  constexpr modint operator*(const modint rhs) const noexcept {
    return modint(*this) *= rhs;
  }
  constexpr modint operator/(const modint rhs) const noexcept {
    return modint(*this) /= rhs;
  }
  constexpr modint &operator+=(const modint rhs) noexcept {
    a += rhs.a;
    if (a >= Modulus) {
      a -= Modulus;
    }
    return *this;
  }
  constexpr modint &operator-=(const modint rhs) noexcept {
    if (a < rhs.a) {
      a += Modulus;
    }
    a -= rhs.a;
    return *this;
  }
  constexpr modint &operator*=(const modint rhs) noexcept {
    a = a * rhs.a % Modulus;
    return *this;
  }
  constexpr modint &operator/=(modint rhs) noexcept {
    u64 exp = Modulus - 2;
    while (exp) {
      if (exp % 2) {
        *this *= rhs;
      }
      rhs *= rhs;
      exp /= 2;
    }
    return *this;
  }
};

利便性のため、a は public 指定しています。(本来は private にするのが良いと思います)
記事に掲載するにあたって、機能はかなり絞りました。足りないと思った部分は是非満足が行くようにカスタマイズしてください。

追加機能例
  • 各種演算子 (operator++, operator--, operator-(単項), operator== 等)
  • 逆元を返す関数 (inverse, operator~ 等)
  • std::cin, std::cout (あるいは自前の入出力) への対応
  • 負数に対応したコンストラク
  • 内部を 32 bit 整数で保持する
  • pow を組み込む
  • 0 除算が起きたときに assert や例外の送出などを行う
  • 法が大きすぎるときに static_assert を行う、あるいは __int128 を使う


実行時に法が決まるとき

法が入力で与えられる場合などは、上記の modint を使用することはできません。
そのような状況で使用できる実装は以下のようになります。

#include <cstdint>

class runtime_modint {
  using u64 = std::uint_fast64_t;

  static u64 &mod() {
    static u64 mod_ = 0;
    return mod_;
  }

public:
  u64 a;

  runtime_modint(const u64 x = 0) : a(x % get_mod()) {}
  u64 &value() noexcept { return a; }
  const u64 &value() const noexcept { return a; }
  runtime_modint operator+(const runtime_modint rhs) const {
    return runtime_modint(*this) += rhs;
  }
  runtime_modint operator-(const runtime_modint rhs) const {
    return runtime_modint(*this) -= rhs;
  }
  runtime_modint operator*(const runtime_modint rhs) const {
    return runtime_modint(*this) *= rhs;
  }
  runtime_modint operator/(const runtime_modint rhs) const {
    return runtime_modint(*this) /= rhs;
  }
  runtime_modint &operator+=(const runtime_modint rhs) {
    a += rhs.a;
    if (a >= get_mod()) {
      a -= get_mod();
    }
    return *this;
  }
  runtime_modint &operator-=(const runtime_modint rhs) {
    if (a < rhs.a) {
      a += get_mod();
    }
    a -= rhs.a;
    return *this;
  }
  runtime_modint &operator*=(const runtime_modint rhs) {
    a = a * rhs.a % get_mod();
    return *this;
  }
  runtime_modint &operator/=(runtime_modint rhs) {
    u64 exp = get_mod() - 2;
    while (exp) {
      if (exp % 2) {
        *this *= rhs;
      }
      rhs *= rhs;
      exp /= 2;
    }
    return *this;
  }

  static void set_mod(const u64 x) { mod() = x; }
  static u64 get_mod() { return mod(); }
};

定数に関する最適化や諸々の恩恵を受けられなくなりますので、使用できる限りは先に示した方の実装を使用した方がよいと考えています。


終わりに

提案や異論は広く募集しています。