平方分割のバケット法で区間の和を効率的に求める (UVa 12086 Potentiometers)

問題

{N} 個の可変抵抗が直列に繋がれている。

  • {x} 番目の抵抗を {r} に変更
  • 区間 {\displaystyle [ x,y ] } の合成抵抗を出力(直列だから足すだけ)

という指示が来るので順番に処理せよ。

平方分割

抵抗もクエリも多いので普通に足すだけだとTLEする。
以前にあり本のFenwick Treeをコピペして解いたけど、平方分割のバケット法というテクニックを知ったので、その練習のためにやり直した。

このスライドが参考になった。 http://www.slideshare.net/iwiwi/ss-3578491
FenwickTreeは更新が {O(\log N)} ,出力も {O(\log N)}、平方分割による {\sqrt N} 分木だと更新に {O(1)} ,出力に { O(2 \sqrt N + N / \sqrt N) = O(\sqrt N}) かかる。実際に掛かった時間はFenwickTreeは約0.21sec、平方分割は約0.43secくらいだった。更新が {O(1)} なのが優秀。

実装していてこれは賢いなーと思った。 バケット内の各要素に二分探索木を生やしたり、一般に {n} 次元でやったりもするらしい。(だからtemplate<typename T> にしてある。) そういえば弾幕STGの本で、当たり判定の部分で16分割くらいにして別々に処理していた気がする。

template<typename T>
struct sqrt_tree{
    vector<T> data, baqet;
    T sq;
    sqrt_tree(vector<T> const& v) :data(v){
        sq = sqrt(data.size());
        baqet.assign((data.size() + sq - 1) / sq, 0);
        rep(i, data.size()) baqet[i / sq] += data[i];
    }
    T sum(int l, int r){
        T res = 0;
        // 同じバケット内にある
        if (l / sq == r / sq){
            loop(i, l, r) res += data[i];
        } else {
            int e = r / sq;
            // 完全に被っているバケット
            loop(i, (l + sq - 1) / sq, e) res += baqet[i];
            // 中途半端な左側のバケット
            if (l%sq){
                int e = sq*(l / sq + 1);
                loop(i, l, e)res += data[i];
            }
            // 中途半端な右側のバケット
            if (r%sq){
                loop(i, r - r%sq, r)res += data[i];
            }
        }
        return res;
    }
    void update(int i, T x){
        baqet[i / sq] -= data[i];
        data[i] = x;
        baqet[i / sq] += x;
    }
};

int main(){
    int t = 1;
    int n;
    while (scanf("%d", &n), n){
        if (t != 1) puts("");
        printf("Case %d:\n", t);
        t++;
        vl v(n);
        rep(i, n) scanf("%lld", &v[i]);
        sqrt_tree<ll> s(v);
        while (1){
            char op[8];
            scanf("%s", op);
            if (op[0] == 'M'){
                int l, r;
                scanf("%d %d", &l, &r);
                l--;
                printf("%lld\n", s.sum(l, r));
            }
            else if (op[0] == 'S'){
                int i; ll x;
                scanf("%d %lld", &i, &x);
                s.update(i - 1, x);
            }
            else{
                goto end;
            }
        }
    end:;
        /*cout << "DEBUG" << endl;
        rep(i, v.size())loop(j, i + 1, v.size() + 1){
        cout << i << " " << j << " " << s.sum(i, j) << endl;
        }*/
    }

}