Submission #4143620
Source Code Expand
#include <bits/stdc++.h>
using namespace std;
#ifdef USE_GRAPHVIZ
#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/graphviz.hpp>
#endif
// Sentinel meaning "no such element": the 64-bit value with every bit set.
// The select / quantile queries return it, and DynamicBitVector::remove uses
// it internally to signal that a leaf refused to give up a bit.
enum {
    NOTFOUND = ~0ULL
};
// Running counter that hands out a distinct id to every BitVectorNode created
// through one of its two constructors.
uint64_t NODE_NO = 0;

// A node of the binary tree that stores the bits of a DynamicBitVector.
//
// A node is either an internal node (is_leaf == false), which only carries
// bookkeeping about its left subtree, or a leaf (is_leaf == true), which
// carries up to 64 bits packed into `bits` with position 0 in the least
// significant bit.
class BitVectorNode {
public:
    uint64_t no;  // node id, taken from NODE_NO at construction

    // Used when the node is an internal node.
    uint64_t num;   // number of bits stored in the left subtree
    uint64_t ones;  // number of 1 bits stored in the left subtree
    BitVectorNode *left;
    BitVectorNode *right;
    int64_t balance;  // height(right) - height(left): positive when the right side is taller

    // Used when the node is a leaf.
    uint64_t bits;       // the packed bits
    uint64_t bits_size;  // how many low bits of `bits` are in use
    bool is_leaf;

    // Creates an empty internal node. Initializers are listed in declaration
    // order, which is the order in which they actually run.
    BitVectorNode()
        : no(NODE_NO++), num(0), ones(0), left(nullptr), right(nullptr), balance(0),
          bits(0), bits_size(0), is_leaf(false) {}

    // Creates a node holding `bits_size` bits taken from `bits`.
    BitVectorNode(uint64_t bits, uint64_t bits_size, bool is_leaf)
        : no(NODE_NO++), num(0), ones(0), left(nullptr), right(nullptr), balance(0),
          bits(bits), bits_size(bits_size), is_leaf(is_leaf) {}

    // Structural sanity check (debugging aid).
    // A leaf must not carry internal-node data (num, ones, children); an
    // internal node must have both children, a non-empty left subtree, no leaf
    // payload, and no more ones than bits on its left side.
    bool is_valid_node() const {
        if (is_leaf) {
            return num == 0 and ones == 0 and left == nullptr and right == nullptr;
        }
        if (num == 0 or left == nullptr or right == nullptr) {
            return false;
        }
        if (bits != 0 or bits_size != 0) {
            return false;
        }
        return ones <= num;
    }

    // Human-readable description of the node (debugging aid).
    // Leaves list their size followed by their bits starting from position 0;
    // internal nodes list num and ones.
    std::string info() const {
        std::string str = "No:" + std::to_string(this->no) + "\n";
        if (is_leaf) {
            str += "size:" + std::to_string(this->bits_size) + "\n";
            for (uint64_t i = 0; i < bits_size; ++i) {
                str += std::to_string((bits >> i) & (uint64_t)1);
            }
        }
        else {
            str += "num:" + std::to_string(this->num) + " ones:" + std::to_string(this->ones) + "\n";
        }
        return str;
    }
};
class DynamicBitVector {
public:
BitVectorNode *root;
uint64_t size; // 全体のbitの数
uint64_t num_one; // 全体の1の数
const uint64_t bits_size_limit = 32; // 各ノードが管理するbitの長さ制限.2 * bits_size_limit以上になったらノードを分割し, 1/2*bits_size_limit以下になったらノードを結合する
DynamicBitVector(): root(nullptr), size(0), num_one(0) {}
DynamicBitVector(std::vector<uint64_t> &v): root(nullptr), size(0), num_one(0) {
if (v.size() == 0) {
return;
}
std::deque<std::pair<BitVectorNode*, uint64_t>> leaves;
for (int i = 0; i < v.size(); i += this->bits_size_limit) {
uint64_t bits = 0;
const uint64_t bits_size = std::min(this->bits_size_limit, v.size() - i);
for (int j = 0; j < bits_size; ++j) {
assert(v[i + j] == 0 or v[i + j] == 1);
if (v[i + j] == 1) {
bits |= (uint64_t)1 << j;
}
}
leaves.emplace_back(std::make_pair(new BitVectorNode(bits, bits_size, true), bits_size));
}
std::deque<std::tuple<BitVectorNode*, uint64_t, uint64_t, uint64_t>> nodes; // (node, 全体のbit数, 全体の1の数, 高さ)
while (not leaves.empty()) {
const auto node = leaves.front().first;
const auto bits_size = leaves.front().second;
leaves.pop_front();
nodes.emplace_back(std::make_tuple(node, bits_size, popCount(node->bits), 0));
}
while (nodes.size() > 1) {
std::deque<std::tuple<BitVectorNode*, uint64_t, uint64_t, uint64_t>> next_nodes;
while (not nodes.empty()) {
// あまりがでたときは,最後に作った中間ノードと結合させるためにnodesに戻す
if (nodes.size() == 1) {
nodes.push_front(next_nodes.back());
next_nodes.pop_back();
}
BitVectorNode *left_node;
uint64_t left_num, left_ones, left_height;
std::tie(left_node, left_num, left_ones, left_height) = nodes.front(); nodes.pop_front();
BitVectorNode *right_node;
uint64_t right_num, right_ones, right_height;
std::tie(right_node, right_num, right_ones, right_height) = nodes.front(); nodes.pop_front();
const auto internal_node = new BitVectorNode(0, 0, false);
internal_node->num = left_num;
internal_node->ones = left_ones;
internal_node->left = left_node;
internal_node->right = right_node;
internal_node->balance = right_height - left_height;
next_nodes.emplace_back(std::make_tuple(internal_node, left_num + right_num, left_ones + right_ones, std::max(left_height, right_height) + 1));
}
nodes = next_nodes;
}
uint64_t height;
std::tie(this->root, this->size, this->num_one, height) = nodes.front(); nodes.pop_front();
assert(this->size == v.size());
}
// B[pos]
uint64_t access(uint64_t pos) {
assert(pos < this->size);
return access(this->root, pos);
}
// B[0, pos)にある指定されたbitの数
uint64_t rank(uint64_t bit, uint64_t pos) {
assert(bit == 0 or bit == 1);
assert(pos <= this->size);
if (bit) {
return rank1(this->root, pos, 0);
}
else {
return pos - rank1(this->root, pos, 0);
}
}
// rank番目の指定されたbitの位置 + 1(rankは1-origin)
uint64_t select(uint64_t bit, uint64_t rank) {
assert(bit == 0 or bit == 1);
assert(rank > 0);
if (bit == 0 and rank > this->size - this-> num_one) { return NOTFOUND; }
if (bit == 1 and rank > this-> num_one) { return NOTFOUND; }
if (bit) {
return select1(this->root, 0, rank);
}
else {
return select0(this->root, 0, rank);
}
}
// posにbitを挿入する
void insert(uint64_t pos, uint64_t bit) {
assert(bit == 0 or bit == 1);
assert(pos <= this->size); // 現在もってるbitsの末尾には挿入できる
if (root == nullptr) {
root = new BitVectorNode(bit, 1, true);
} else {
insert(this->root, nullptr, bit, pos, this->size);
}
this->size++;
this->num_one += (bit == 1);
}
// 末尾にbitを追加する
void push_back(uint64_t bit) {
assert(bit == 0 or bit == 1);
this->insert(this->size, bit);
}
// posを削除する
void erase(uint64_t pos) {
assert(pos < this->size);
uint64_t bit = this->remove(this->root, nullptr, pos, this->size, 0, true).first;
this->size--;
this->num_one -= (bit == 1);
}
// posにbitをセットする
void update(uint64_t pos, uint64_t bit) {
assert(bit == 0 or bit == 1);
assert(pos < this->size);
if (bit == 1) {
this->bitset(pos);
}
else {
this->bitclear(pos);
}
}
// posのbitを1にする
void bitset(uint64_t pos) {
assert(pos < this->size);
bool flip = bitset(this->root, pos);
this->num_one += flip;
}
// posのbitを0にする
void bitclear(uint64_t pos) {
assert(pos < this->size);
bool flip = bitclear(this->root, pos);
this->num_one -= flip;
}
// dotファイルを作成する(debug用)
void graphviz(const std::string &file_path) {
#ifdef USE_GRAPHVIZ
boost::adjacency_list<> graph;
std::vector<std::string> labels;
auto root = boost::add_vertex(graph);
labels.emplace_back(this->root->info());
std::queue<std::pair<Node*, boost::adjacency_list<>::vertex_descriptor>> que;
que.emplace(std::make_pair(this->root, root));
while (not que.empty()) {
Node *node = que.front().first;
auto parent = que.front().second;
que.pop();
if (not node->is_leaf) {
// left
auto left = boost::add_vertex(graph);
boost::add_edge(parent, left, graph);
labels.emplace_back("L\n" + node->left->info());
que.emplace(std::make_pair(node->left, left));
// right
auto right = boost::add_vertex(graph);
boost::add_edge(parent, right, graph);
labels.emplace_back("R\n" + node->right->info());
que.emplace(std::make_pair(node->right, right));
}
}
std::ofstream file(file_path);
boost::write_graphviz(file, graph, boost::make_label_writer(&labels[0]));
#else
std::cerr << "please define USE_GRAPHVIZ" << std::endl;
#endif
}
// 木の状態が正しいかチェックする(debug用)
// 各ノードのbalanceの値が正しいか.AVL木の制約を守っているかのチェック
bool is_valid_tree(bool verbose) {
if (this->root == nullptr) {
if (this->size == 0 and this->num_one == 0) {
return true;
}
std::cerr << "root is nullptr but size is " << this->size << " and num_one is " << this->num_one << std::endl;
return false;
}
std::map<uint64_t, uint64_t> height;
get_height(this->root, height);
std::queue<BitVectorNode*> que;
que.emplace(this->root);
while (not que.empty()) {
BitVectorNode *node = que.front(); que.pop();
if (not node->is_valid_node()) {
if (verbose) {
std::cerr << "node " << node->no << " is invalid node" << std::endl;
std::cerr << node->info() << std::endl;
}
return false;
}
if (not node->is_leaf) {
auto left_height = height[node->left->no];
auto right_height = height[node->right->no];
// バランスの値が正しいかチェック
if (node->balance != right_height - left_height) {
if (verbose) {
std::cerr << "node" << node->no << "'s balance is " << node->balance << "(left height:" << left_height << ", right height:" << right_height << ")" << std::endl;
}
return false;
}
// AVL木の制約を満たしていない
if (node->balance < -1 or 1 < node->balance) {
if (verbose) {
std::cerr << "node" << node->no << "is not balanced." << "(balance:" << node->balance << ", left height:" << left_height << ", right height:" << right_height << ")" << std::endl;
}
return false;
}
que.emplace(node->left);
que.emplace(node->right);
}
else {
// バランスの値が正しいかチェック
if (node->balance != 0) {
if (verbose) {
std::cerr << "node " << node->no << "'s balance is not 0" << std::endl;
}
return false;
}
}
}
return true;
}
private:
uint64_t access(const BitVectorNode *node, uint64_t pos) {
if (node->is_leaf) {
assert(pos <= 2 * this->bits_size_limit);
return (node->bits >> pos) & (uint64_t)1;
}
if (pos < node->num) {
return this->access(node->left, pos);
} else {
return this->access(node->right, pos - node->num);
}
}
uint64_t rank1(const BitVectorNode *node, uint64_t pos, uint64_t ones) {
if (node->is_leaf) {
assert(node->bits_size >= pos);
const uint64_t mask = ((uint64_t)1 << pos) - 1;
return ones + popCount(node->bits & mask);
}
if (pos < node->num) {
return this->rank1(node->left, pos, ones);
} else {
return this->rank1(node->right, pos - node->num, ones + node->ones);
}
}
uint64_t select1(const BitVectorNode *node, uint64_t pos, uint64_t rank) {
if (node->is_leaf) {
return pos + this->selectInBlock(node->bits, rank) + 1;
}
if (rank <= node->ones) {
return this->select1(node->left, pos, rank);
} else {
return this->select1(node->right, pos + node->num, rank - node->ones);
}
}
uint64_t select0(const BitVectorNode *node, uint64_t pos, uint64_t rank) {
if (node->is_leaf) {
return pos + this->selectInBlock(~node->bits, rank) + 1;
}
if (rank <= (node->num - node->ones)) {
return this->select0(node->left, pos, rank);
} else {
return this->select0(node->right, pos + node->num, rank - (node->num - node->ones));
}
}
// bits(64bit)のrank番目(0-index)の1の数
uint64_t selectInBlock(uint64_t bits, uint64_t rank) {
const uint64_t x1 = bits - ((bits & 0xAAAAAAAAAAAAAAAALLU) >> (uint64_t)1);
const uint64_t x2 = (x1 & 0x3333333333333333LLU) + ((x1 >> (uint64_t)2) & 0x3333333333333333LLU);
const uint64_t x3 = (x2 + (x2 >> (uint64_t)4)) & 0x0F0F0F0F0F0F0F0FLLU;
uint64_t pos = 0;
for (;; pos += 8){
const uint64_t rank_next = (x3 >> pos) & 0xFFLLU;
if (rank <= rank_next) break;
rank -= rank_next;
}
const uint64_t v2 = (x2 >> pos) & 0xFLLU;
if (rank > v2) {
rank -= v2;
pos += 4;
}
const uint64_t v1 = (x1 >> pos) & 0x3LLU;
if (rank > v1){
rank -= v1;
pos += 2;
}
const uint64_t v0 = (bits >> pos) & 0x1LLU;
if (v0 < rank){
pos += 1;
}
return pos;
}
// nodeから辿れる葉のpos番目にbitをいれる(葉のbitの長さはlen)
// 高さの変化を返す
int64_t insert(BitVectorNode *node, BitVectorNode *parent, uint64_t bit, uint64_t pos, uint64_t len) {
assert(bit == 0 or bit == 1);
if (node->is_leaf) {
assert(pos <= 2 * this->bits_size_limit);
// posより左をとりだす
const uint64_t up_mask = (((uint64_t)1 << (len - pos)) - 1) << pos;
const uint64_t up = (node->bits & up_mask);
// posより右をとりだす
const uint64_t down_mask = ((uint64_t)1 << pos) - 1;
const uint64_t down = node->bits & down_mask;
node->bits = (up << (uint64_t)1) | (bit << pos) | down;
node->bits_size++;
assert(node->bits_size == len + 1);
// bitsのサイズが大きくなったので分割する
if (len + 1 >= 2 * bits_size_limit) {
this->splitNode(node, len + 1); // 引数のlenはinsert後の長さなので+1する
return 1;
}
return 0;
}
if (pos < node->num) {
const int64_t diff = this->insert(node->left, node, bit, pos, node->num);
node->num += 1;
node->ones += (bit == 1);
return achieveBalance(parent, node, diff, 0);
} else {
const int64_t diff = this->insert(node->right, node, bit, pos - node->num, len - node->num);
return achieveBalance(parent, node, 0, diff);
}
}
// nodeのpos番目のbitを削除する
// 消したbitと高さの変化のpairを返す
std::pair<uint64_t, int64_t> remove(BitVectorNode *node, BitVectorNode *parent, uint64_t pos, uint64_t len, uint64_t ones, bool allow_under_flow) {
if (node->is_leaf) {
// 消すとunder flowになるので消さない
if (node != this->root and len <= bits_size_limit / 2 and not allow_under_flow) {
return std::make_pair(NOTFOUND, 0);
}
assert(pos <= 2 * this->bits_size_limit);
assert(pos < len);
const uint64_t bit = (node->bits >> pos) & (uint64_t)1;
// posより左をとりだす(posを含まないようにする)
const uint64_t up_mask = (((uint64_t)1 << (len - pos - 1)) - 1) << (pos + 1);
const uint64_t up = (node->bits & up_mask);
// posより右をとりだす
const uint64_t down_mask = ((uint64_t)1 << pos) - 1;
const uint64_t down = node->bits & down_mask;
node->bits = (up >> (uint64_t)1) | down;
node->bits_size--;
assert(node->bits_size == len - 1);
return std::make_pair(bit, 0);
}
if (pos < node->num) {
const auto bit_diff = this->remove(node->left, node, pos, node->num, node->ones, allow_under_flow);
if (bit_diff.first == NOTFOUND) {
return bit_diff;
}
node->ones -= (bit_diff.first == 1);
// 左の子の葉のbitを1つ消したのでunder flowが発生している
if (node->num == bits_size_limit / 2) {
const auto b_d = remove(node->right, node, 0, len - bits_size_limit / 2, 0, false); // 右の葉の先頭bitを削る
// 右の葉もunder flowになって消せない場合は2つの葉を統合する
if (b_d.first == NOTFOUND) {
assert(node->left->is_leaf);
assert(node->left->bits_size == bits_size_limit / 2 - 1);
// 右の子から辿れる一番左の葉の先頭にleftのbitsを追加する
mergeNodes(node->right, 0, len - bits_size_limit / 2, node->left->bits, bits_size_limit / 2 - 1, node->ones, true);
this->replace(parent, node, node->right); // parentの子のnodeをnode->rightに置き換える
return std::make_pair(bit_diff.first, -1);
}
// 右の葉からとった先頭bitを左の葉の末尾にいれる
assert(node->left->bits_size == bits_size_limit / 2 - 1);
insert(node->left, node, b_d.first, bits_size_limit / 2 - 1, bits_size_limit / 2 - 1);
node->ones += (b_d.first == 1);
}
else {
node->num -= 1;
}
const int64_t diff = achieveBalance(parent, node, bit_diff.second, 0);
return std::make_pair(bit_diff.first, diff);
} else {
const auto bit_diff = this->remove(node->right, node, pos - node->num, len - node->num, ones - node->ones, allow_under_flow);
if (bit_diff.first == NOTFOUND) {
return bit_diff;
}
ones -= (bit_diff.first == 1);
// 右の子の葉のbitを1つ消したのでunder flowが発生する
if ((len - node->num) == bits_size_limit / 2) {
const auto b_d = remove(node->left, node, node->num - 1, node->num, 0, false); // 左の葉の末尾をbitを削る
// 左の葉もunder flowになって消せない場合は2つの葉を統合する
if (b_d.first == NOTFOUND) {
assert(node->right->is_leaf);
assert(node->right->bits_size == bits_size_limit / 2 - 1);
// 左の子から辿れる一番右の葉の末尾にleftのbitsを追加する
mergeNodes(node->left, node->num, node->num, node->right->bits, bits_size_limit / 2 - 1, ones - node->ones, false);
this->replace(parent, node, node->left); // parentの子のnodeをnode->leftに置き換える
return std::make_pair(bit_diff.first, -1);
}
// 左の葉からとった最後尾bitを右の葉の先頭にいれる
insert(node->right, node, b_d.first, 0, bits_size_limit / 2 - 1);
node->num -= 1;
node->ones -= (b_d.first == 1);
}
const int64_t diff = achieveBalance(parent, node, 0, bit_diff.second);
return std::make_pair(bit_diff.first, diff);
}
}
// pos番目のbitを1にする.0から1への反転が起きたらtrueを返す
bool bitset(BitVectorNode *node, uint64_t pos) {
if (node->is_leaf) {
assert(pos <= 2 * this->bits_size_limit);
const uint64_t bit = (node->bits >> pos) & 1;
if (bit == 1) {
return false;
}
node->bits |= ((uint64_t)1 << pos);
return true;
}
if (pos < node->num) {
bool flip = this->bitset(node->left, pos);
node->ones += flip;
return flip;
} else {
return this->bitset(node->right, pos - node->num);
}
}
// pos番目のbitを0にする.1から0への反転が起きたらtrueを返す
bool bitclear(BitVectorNode *node, uint64_t pos) {
if (node->is_leaf) {
assert(pos <= 2 * this->bits_size_limit);
const uint64_t bit = (node->bits >> pos) & 1;
if (bit == 0) {
return false;
}
node->bits &= ~((uint64_t)1 << pos);
return true;
}
if (pos < node->num) {
const bool flip = this->bitclear(node->left, pos);
node->ones -= flip;
return flip;
} else {
return this->bitclear(node->right, pos - node->num);
}
}
// nodeを2つの葉に分割する
void splitNode(BitVectorNode *node, uint64_t len) {
assert(node->is_leaf);
assert(node->bits_size == len);
// 左の葉
const uint64_t left_size = len / 2;
const uint64_t left_bits = node->bits & (((uint64_t)1 << left_size) - 1);
node->left = new BitVectorNode(left_bits, left_size, true);
// 右の葉
const uint64_t right_size = len - left_size;
const uint64_t right_bits = node->bits >> left_size;
node->right = new BitVectorNode(right_bits, right_size, true);
// nodeを内部ノードにする
node->is_leaf = false;
node->num = left_size;
node->ones = this->popCount(left_bits);
node->bits = 0;
node->bits_size = 0;
}
// nodeから辿れる葉のpos番目にbitsを格納する
void mergeNodes(BitVectorNode *node, uint64_t pos, uint64_t len, uint64_t bits, uint64_t s, uint64_t ones, bool left) {
if (node->is_leaf) {
if (left) {
node->bits = (node->bits << s) | bits;
}
else {
assert(len == node->bits_size);
node->bits = node->bits | (bits << len);
}
node->bits_size += s;
return;
}
if (pos < node->num) {
node->num += s;
node->ones += ones;
mergeNodes(node->left, pos, node->num, bits, s, ones, left);
}
else {
mergeNodes(node->right, pos, len - node->num, bits, s, ones, left);
}
}
// nodeの左の高さがleftHeightDiffだけ変わり,右の高さがrightHeightDiffだけ変わったときにnodeを中心に回転させる
// 高さの変化を返す
int64_t achieveBalance(BitVectorNode *parent, BitVectorNode *node, int64_t leftHeightDiff, int64_t rightHeightDiff) {
assert(-1 <= node->balance and node->balance <= 1);
assert(-1 <= leftHeightDiff and leftHeightDiff <= 1);
assert(-1 <= rightHeightDiff and rightHeightDiff <= 1);
if (leftHeightDiff == 0 and rightHeightDiff == 0) {
return 0;
}
int64_t heightDiff = 0;
// 左が高いときに,左が高くなる or 右が高いときに右が高くなる
if ((node->balance <= 0 and leftHeightDiff > 0) or (node->balance >= 0 and rightHeightDiff > 0)) {
++heightDiff;
}
// 左が高いときに左が低くなる or 右が高いときに右が低くなる
if ((node->balance < 0 and leftHeightDiff < 0) or (node->balance > 0 and rightHeightDiff < 0)) {
--heightDiff;
}
node->balance += -leftHeightDiff + rightHeightDiff;
assert(-2 <= node->balance and node->balance <= 2);
// 左が2高い
if (node->balance == -2) {
assert(-1 <= node->left->balance and node->left->balance <= 1);
if (node->left->balance != 0) {
heightDiff--;
}
if (node->left->balance == 1) {
replace(node, node->left, rotateLeft(node->left));
}
replace(parent, node, rotateRight(node));
}
// 右が2高い
else if (node->balance == 2) {
assert(-1 <= node->right->balance and node->right->balance <= 1);
if (node->right->balance != 0) {
heightDiff--;
}
if (node->right->balance == -1) {
replace(node, node->right, rotateRight(node->right));
}
replace(parent, node, rotateLeft(node));
}
return heightDiff;
}
// node Bを中心に左回転する.新しい親を返す
BitVectorNode *rotateLeft(BitVectorNode *B) {
BitVectorNode *D = B->right;
const int64_t heightC = 0;
const int64_t heightE = heightC + D->balance;
const int64_t heightA = std::max(heightC, heightE) + 1 - B->balance;
B->right = D->left;
D->left = B;
B->balance = heightC - heightA;
D->num += B->num;
D->ones += B->ones;
D->balance = heightE - (std::max(heightA, heightC) + 1);
assert(-2 <= B->balance and B->balance <= 2);
assert(-2 <= D->balance and D->balance <= 2);
return D;
}
// node Dを中心に右回転する.新しい親を返す
BitVectorNode *rotateRight(BitVectorNode *D) {
BitVectorNode *B = D->left;
const int64_t heightC = 0;
const int64_t heightA = heightC - B->balance;
const int64_t heightE = std::max(heightA, heightC) + 1 + D->balance;
D->left = B->right;
B->right = D;
D->num -= B->num;
D->ones -= B->ones;
D->balance = heightE - heightC;
B->balance = std::max(heightC, heightE) + 1 - heightA;
assert(-2 <= B->balance and B->balance <= 2);
assert(-2 <= D->balance and D->balance <= 2);
return B;
}
// parentの子のoldNodeをnewNodeに置き換える
void replace(BitVectorNode *parent, BitVectorNode *oldNode, BitVectorNode *newNode) {
if (parent == nullptr) {
this->root = newNode;
return;
}
if (parent->left == oldNode) {
parent->left = newNode;
}
else if (parent->right == oldNode) {
parent->right = newNode;
}
else {
throw "old node is not child";
}
}
uint64_t popCount(uint64_t x) {
x = (x & 0x5555555555555555ULL) + ((x >> (uint64_t)1) & 0x5555555555555555ULL);
x = (x & 0x3333333333333333ULL) + ((x >> (uint64_t)2) & 0x3333333333333333ULL);
x = (x + (x >> (uint64_t)4)) & 0x0f0f0f0f0f0f0f0fULL;
x = x + (x >> (uint64_t)8);
x = x + (x >> (uint64_t)16);
x = x + (x >> (uint64_t)32);
return x & 0x7FLLU;
}
// 各ノードの高さ(一番遠い葉からの距離)を取得(debug用)
uint64_t get_height(BitVectorNode *node, std::map<uint64_t, uint64_t> &height) {
if (node->is_leaf) {
return 0;
}
if (height.find(node->no) != height.end()) {
return height[node->no];
}
auto left_height = get_height(node->left, height);
auto right_height = get_height(node->right, height);
return height[node->no] = std::max(left_height, right_height) + 1;
}
};
// One node of the dynamic wavelet tree.
// It owns its two children and the bit vector of this level, and refers back
// to its parent through a non-owning weak_ptr.
class WaveletNode {
public:
    std::weak_ptr<WaveletNode> parent;
    std::shared_ptr<WaveletNode> left;
    std::shared_ptr<WaveletNode> right;
    DynamicBitVector bitVector;

    // Creates a node without parent or children.
    WaveletNode() {}

    // Creates a childless node that hangs under `parent`.
    WaveletNode(std::weak_ptr<WaveletNode> parent) : parent(parent) {}
};
class DynamicWaveletTree {
public:
shared_ptr<WaveletNode> root;
vector<weak_ptr<WaveletNode>> leaves;
uint64_t size;
const uint64_t maximum_element; // 最大の数値
uint64_t bit_size; // 文字を表すのに必要なbit数
public:
DynamicWaveletTree(uint64_t maximum_element) : root(new WaveletNode), size(0), maximum_element(maximum_element + 1) {
this->bit_size = this->get_num_of_bit(maximum_element);
this->leaves.resize(maximum_element);
}
DynamicWaveletTree(uint64_t maximum_element, const std::vector<uint64_t> &array) : root(new WaveletNode), size(array.size()), maximum_element(maximum_element + 1) {
this->bit_size = this->get_num_of_bit(maximum_element);
this->leaves.resize(maximum_element);
for (int i = 0; i < array.size(); ++i) {
uint64_t c = array[i];
auto node = this->root;
for (int j = 0; j < bit_size; ++j) {
const uint64_t bit = (c >> (bit_size - j - 1)) & 1; // 上からj番目のbit
node->bitVector.push_back(bit);
if (j == bit_size - 1) {
break;
}
if (bit == 0) {
if (node->left == nullptr) {
shared_ptr<WaveletNode> left(new WaveletNode(node));
node->left = left;
}
node = node->left;
}
else {
if (node->right == nullptr) {
shared_ptr<WaveletNode> right(new WaveletNode(node));
node->right = right;
}
node = node->right;
}
}
this->leaves[c] = node;
}
}
uint64_t access(uint64_t pos) {
assert(pos < this->size);
auto node = this->root;
uint64_t c = 0;
for (int i = 0; i < bit_size; ++i) {
const uint64_t bit = node->bitVector.access(pos); // T[pos]のi番目のbit
c = (c <<= 1) | bit;
pos = node->bitVector.rank(bit, pos);
if (bit == 0) {
node = node->left;
}
else {
node = node->right;
}
}
return c;
}
// v[0, pos)のcの数
uint64_t rank(uint64_t c, uint64_t pos) {
assert(pos <= size);
if (c >= maximum_element) {
return 0;
}
auto node = this->root;
for (uint64_t i = 0; i < bit_size; ++i) {
const uint64_t bit = (c >> (bit_size - i - 1)) & 1; // 上からi番目のbit
pos = node->bitVector.rank(bit, pos); // cのi番目のbitと同じ数値の数
node = bit == 0 ? node->left : node->right;
}
return pos;
}
// i番目のcの位置 + 1を返す。rankは1-origin
uint64_t select(uint64_t c, uint64_t rank) {
assert(rank > 0);
if (c >= maximum_element) {
return NOTFOUND;
}
auto node = this->leaves[c];
for (int i = 0; i < bit_size; ++i) {
uint64_t bit = ((c >> i) & 1); // 下からi番目のbit
auto n = node.lock();
rank = n->bitVector.select(bit, rank);
node = n->parent;
}
return rank;
}
// posにcを挿入する
void insert(uint64_t pos, uint64_t c) {
assert(pos <= this->size);
auto node = this->root;
for (uint64_t i = 0; i < bit_size; ++i) {
const uint64_t bit = (c >> (bit_size - i - 1)) & 1; // 上からi番目のbit
node->bitVector.insert(pos, bit);
pos = node->bitVector.rank(bit, pos);
if (i == bit_size - 1) {
break;
}
if (bit == 0) {
if (node->left == nullptr) {
shared_ptr<WaveletNode> left(new WaveletNode(node));
node->left = left;
}
node = node->left;
}
else {
if (node->right == nullptr) {
shared_ptr<WaveletNode> right(new WaveletNode(node));
node->right = right;
}
node = node->right;
}
}
this->size++;
this->leaves[c] = node;
}
// 末尾にcを追加する
void push_back(uint64_t c) {
this->insert(this->size, c);
}
// posを削除する
uint64_t erase(uint64_t pos) {
assert(pos < this->size);
auto node = this->root;
uint64_t c = 0;
for (uint64_t i = 0; i < bit_size; ++i) {
uint64_t bit = node->bitVector.access(pos); // もとの数値のi番目のbit
c = (c <<= 1) | bit;
auto next_pos = node->bitVector.rank(bit, pos);
node->bitVector.erase(pos);
node = bit == 0 ? node->left : node->right;
pos = next_pos;
}
this->size--;
return c;
}
void update(uint64_t pos, uint64_t c) {
this->erase(pos);
this->insert(pos, c);
}
// v[begin_pos, end_pos)でk番目に小さい数値を返す(kは0-origin)
// つまり小さい順に並べてk番目の値
uint64_t quantileRange(uint64_t begin_pos, uint64_t end_pos, uint64_t k) {
if ((end_pos > size || begin_pos >= end_pos) || (k >= end_pos - begin_pos)) {
return NOTFOUND;
}
auto node = this->root;
uint64_t val = 0;
for (uint64_t i = 0; i < bit_size; ++i) {
const uint64_t num_of_zero_begin = node->bitVector.rank(0, begin_pos);
const uint64_t num_of_zero_end = node->bitVector.rank(0, end_pos);
const uint64_t num_of_zero = num_of_zero_end - num_of_zero_begin; // beginからendまでにある0の数
const uint64_t bit = (k < num_of_zero) ? 0 : 1; // k番目の値の上からi番目のbitが0か1か
if (bit == 0) {
node = node->left;
begin_pos = num_of_zero_begin;
end_pos = num_of_zero_begin + num_of_zero;
}
else {
node = node->right;
k -= num_of_zero;
begin_pos = begin_pos - num_of_zero_begin;
end_pos = end_pos - num_of_zero_end;
}
val = ((val << 1) | bit);
}
return val;
}
private:
uint64_t get_num_of_bit(uint64_t x) {
if (x == 0) return 0;
x--;
uint64_t bit_num = 0;
while (x >> bit_num) {
++bit_num;
}
return bit_num;
}
};
// Reads Q queries "T X" from standard input.
//   T == 1: append X to the sequence.
//   otherwise: print the X-th smallest element currently stored (X counts
//              from 1) and remove one occurrence of it.
int main() {
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);

    DynamicWaveletTree dwt(200000);
    int Q;
    if (not (std::cin >> Q)) {
        return 0;
    }
    for (int i = 0; i < Q; ++i) {
        int T, X;
        // Stop on truncated input instead of working with indeterminate values.
        if (not (std::cin >> T >> X)) {
            break;
        }

        if (T == 1) {
            dwt.push_back(X);
        }
        else {
            const auto c = dwt.quantileRange(0, dwt.size, X - 1);
            // '\n' instead of std::endl: no flush per answer; the stream is flushed at exit.
            std::cout << c << '\n';
            const auto idx = dwt.select(c, 1);  // (position + 1) of the first occurrence
            dwt.erase(idx - 1);
        }
    }

    return 0;
}
Submission Info

| Field | Value |
|---|---|
| Submission Time | |
| Task | C - データ構造 |
| User | MitI_7 |
| Language | C++14 (GCC 5.4.1) |
| Score | 100 |
| Code Size | 37287 Byte |
| Status | AC |
| Exec Time | 1395 ms |
| Memory | 55680 KB |
Judge Result

| Set Name | Sample | All |
|---|---|---|
| Score / Max Score | 0 / 0 | 100 / 100 |
| Status | | |

| Set Name | Test Cases |
|---|---|
| Sample | sample_01.txt, sample_02.txt |
| All | sample_01.txt, sample_02.txt, subtask1_01.txt, subtask1_02.txt, subtask1_03.txt, subtask1_04.txt, subtask1_05.txt, subtask1_06.txt, subtask1_07.txt, subtask1_08.txt, subtask1_09.txt, subtask1_10.txt, subtask1_11.txt, subtask1_12.txt, subtask1_13.txt, subtask1_14.txt, subtask1_15.txt, subtask1_16.txt |
| Case Name | Status | Exec Time | Memory |
|---|---|---|---|
| sample_01.txt | AC | 3 ms | 3456 KB |
| sample_02.txt | AC | 3 ms | 3456 KB |
| subtask1_01.txt | AC | 3 ms | 3328 KB |
| subtask1_02.txt | AC | 3 ms | 3328 KB |
| subtask1_03.txt | AC | 4 ms | 3584 KB |
| subtask1_04.txt | AC | 40 ms | 10368 KB |
| subtask1_05.txt | AC | 93 ms | 15616 KB |
| subtask1_06.txt | AC | 4 ms | 3840 KB |
| subtask1_07.txt | AC | 565 ms | 40320 KB |
| subtask1_08.txt | AC | 591 ms | 36480 KB |
| subtask1_09.txt | AC | 557 ms | 36096 KB |
| subtask1_10.txt | AC | 990 ms | 48896 KB |
| subtask1_11.txt | AC | 994 ms | 48896 KB |
| subtask1_12.txt | AC | 582 ms | 55680 KB |
| subtask1_13.txt | AC | 1330 ms | 44032 KB |
| subtask1_14.txt | AC | 1395 ms | 44032 KB |
| subtask1_15.txt | AC | 574 ms | 47104 KB |
| subtask1_16.txt | AC | 524 ms | 36736 KB |