Set with range sums: Binary Search Trees

 Set with range sums

Problem Introduction

In this problem, your goal is to implement a data structure to store a set of integers and quickly compute range sums.

Problem Description

Task. Implement a data structure that stores a set 𝑆 of integers with the following allowed operations:

∙ add(𝑖) — add integer 𝑖 into the set 𝑆 (if it was there already, the set doesn’t change).

∙ del(𝑖) — remove integer 𝑖 from the set 𝑆 (if there was no such element, nothing happens).

∙ find(𝑖) — check whether 𝑖 is in the set 𝑆 or not.

∙ sum(𝑙, 𝑟) — output the sum of all elements 𝑣 in 𝑆 such that 𝑙 ≤ 𝑣 ≤ 𝑟.

Input Format. Initially the set 𝑆 is empty. The first line contains 𝑛 — the number of operations. The next 𝑛 lines contain operations. Each operation is one of the following:

∙ "+ i" — which means add some integer (not 𝑖, see below) to 𝑆,

∙ "- i" — which means del some integer (not 𝑖, see below) from 𝑆,

∙ "? i" — which means find some integer (not 𝑖, see below) in 𝑆,

∙ "s l r" — which means compute the sum of all elements of 𝑆 within some range of values (not from 𝑙 to 𝑟, see below).

However, to make sure that your solution can work in an online fashion, each request will actually depend on the result of the last sum request. Denote 𝑀 = 1 000 000 001. At any moment, let 𝑥 be the result of the last sum operation, or just 0 if there were no sum operations before. Then

∙ "+ i" means add((𝑖 + 𝑥) mod 𝑀),

∙ "- i" means del((𝑖 + 𝑥) mod 𝑀),

∙ "? i" means find((𝑖 + 𝑥) mod 𝑀),

∙ "s l r" means sum((𝑙 + 𝑥) mod 𝑀, (𝑟 + 𝑥) mod 𝑀).

Constraints. $1 \le n \le 100\,000$; $0 \le i \le 10^9$.

Output Format. For each find request, just output "Found" or "Not found" (without quotes; note that the first letter is capitalized) depending on whether (𝑖 + 𝑥) mod 𝑀 is in 𝑆 or not. For each sum query, output the sum of all the values 𝑣 in 𝑆 such that ((𝑙 + 𝑥) mod 𝑀) ≤ 𝑣 ≤ ((𝑟 + 𝑥) mod 𝑀) (it is guaranteed that in all the tests ((𝑙 + 𝑥) mod 𝑀) ≤ ((𝑟 + 𝑥) mod 𝑀)), where 𝑥 is the result of the last sum operation, or 0 if there was no previous sum operation.

Time Limits.

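| Language   | C | C++ | Java | Python | C#  | Haskell | JavaScript | Ruby | Scala |
|------------|---|-----|------|--------|-----|---------|------------|------|-------|
| Time (sec) | 1 | 1   | 4    | 120    | 1.5 | 2       | 120        | 120  | 4     |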

Memory Limit. 512MB.

Sample 1.

Input:

15

? 1

+ 1

? 1

+ 2

s 1 2

+ 1000000000

? 1000000000

- 1000000000

? 1000000000

s 999999999 1000000000

- 2

? 2

- 0

+ 9

s 0 9

Output:

Not found

Found

3

Found

Not found

1

Not found

10

Explanation:

For the first 5 queries, 𝑥 = 0. For the next 5 queries, 𝑥 = 3. For the last 5 queries, 𝑥 = 1. The actual list of operations is:

find(1)

add(1)

find(1)

add(2)

sum(1, 2) → 3

add(2)

find(2) → Found

del(2)

find(2) → Not found

sum(1, 2) → 1

del(3)

find(3) → Not found

del(1)

add(10)

sum(1, 10) → 10

Adding the same element twice doesn't change the set. Attempts to remove an element which is not in the set are ignored.



Solution:


# 26\2\2021
#It's so :()

import sys
class Vertex:
    """One node of the splay tree.

    ``key`` is the stored integer and ``sum`` caches the total of all keys in
    the subtree rooted at this node. ``left``/``right``/``parent`` are links
    to neighbouring ``Vertex`` objects, or None.
    """

    # Up to one node per stored key is allocated, so drop the per-instance
    # __dict__ to save memory and speed up attribute access.
    __slots__ = ("key", "sum", "left", "right", "parent")

    def __init__(self, key, sum, left, right, parent):
        self.key, self.sum, self.left, self.right, self.parent = \
            key, sum, left, right, parent

class SplayTree:
    """Splay-tree operations on ``Vertex`` nodes, kept as class/static methods.

    The class holds no state: each method takes a (sub)tree root and, where
    the shape can change, returns the new root for the caller to store.
    Every vertex's ``sum`` is kept equal to the sum of keys in its subtree by
    calling ``update`` after each structural change.
    """

    @staticmethod
    def update(v):
        """Recompute ``v.sum`` and point v's children back at ``v``.

        Does nothing when ``v`` is None. The children's own ``sum`` values
        are read as-is, so they must already be up to date.
        """
        if v is None:
            return
        v.sum = v.key + (v.left.sum if v.left is not None else 0) + (
            v.right.sum if v.right is not None else 0)
        if v.left is not None:
            v.left.parent = v
        if v.right is not None:
            v.right.parent = v

    @classmethod
    def _small_rotation(cls, v):
        """Rotate ``v`` one level up, above its current parent.

        The subtree on v's inner side (its right child if v is a left child,
        otherwise its left child) is re-hung under the old parent. Does
        nothing if ``v`` is already the root.
        """
        parent = v.parent
        if parent is None:
            return
        grandparent = v.parent.parent
        if parent.left == v:
            m = v.right
            v.right = parent
            parent.left = m
        else:
            m = v.left
            v.left = parent
            parent.right = m
        # The old parent is now v's child, so its sum must be refreshed first.
        cls.update(parent)
        cls.update(v)
        v.parent = grandparent
        # Make the grandparent (if any) link to v where it used to link parent.
        if grandparent is not None:
            if grandparent.left == parent:
                grandparent.left = v
            else:
                grandparent.right = v

    @classmethod
    def _big_rotation(cls, v):
        """Do one double-rotation step; ``v`` must have a grandparent.

        When v and its parent are children on the same side (zig-zig), the
        parent is rotated first and then v. Otherwise (zig-zag), v is rotated
        twice.
        """
        if v.parent.left == v and v.parent.parent.left == v.parent:
            cls._small_rotation(v.parent)
            cls._small_rotation(v)
        elif v.parent.right == v and v.parent.parent.right == v.parent:
            cls._small_rotation(v.parent)
            cls._small_rotation(v)
        else:
            cls._small_rotation(v)
            cls._small_rotation(v)

    @classmethod
    def splay(cls, v):
        """Rotate ``v`` up until it is the root and return it (None -> None)."""
        if v is None:
            return None
        while v.parent is not None:
            # Only a single rotation is left when the parent is the root.
            if v.parent.parent is None:
                cls._small_rotation(v)
                break
            cls._big_rotation(v)
        return v

    @classmethod
    def find(cls, root, key):
        """Search for ``key`` and return ``(next_, root)``.

        ``next_`` is the vertex with the smallest key >= ``key`` seen on the
        search path, or None if no such vertex exists. In a valid BST that is
        the smallest such key in the whole tree. The last vertex visited is
        splayed, and ``root`` is the resulting new root. ``next_`` is not
        necessarily that root.
        """
        v = root
        last = root
        next_ = None
        while v is not None:
            if v.key >= key and (next_ is None or v.key < next_.key):
                next_ = v
            last = v
            if v.key == key:
                break
            if v.key < key:
                v = v.right
            else:
                v = v.left
        root = cls.splay(last)
        return next_, root

    @classmethod
    def split(cls, root, key):
        """Split the tree into ``(left, right)`` around ``key``.

        ``left`` holds the keys < ``key`` and ``right`` holds the keys >=
        ``key``. Either part may be None.
        """
        # NOTE(review): calls SplayTree.find rather than cls.find; identical
        # unless this class is ever subclassed.
        result, root = SplayTree.find(root, key)
        if result is None:
            # Every key is < key: the whole tree goes to the left part.
            return root, None
        # find() splayed the last visited vertex, which may differ from
        # result, so bring result itself to the root before cutting.
        right = cls.splay(result)
        left = right.left
        right.left = None
        if left is not None:
            left.parent = None
        cls.update(left)
        cls.update(right)
        return left, right

    @classmethod
    def merge(cls, left, right):
        """Join two trees and return the new root.

        Every key in ``left`` must be smaller than every key in ``right``,
        because ``left`` is attached below the minimum of ``right``.
        """
        if left is None:
            return right
        if right is None:
            return left
        # Walk to the minimum of the right tree and splay it to the root;
        # it then has no left child, so the left tree can hang there.
        while right.left is not None:
            right = right.left
        right = cls.splay(right)
        right.left = left
        cls.update(right)
        return right


class Set:
    """Set of integers backed by a splay tree, with range-sum queries.

    ``root`` is the current splay-tree root (None while the set is empty).
    Every operation restructures the tree through ``SplayTree``, so ``root``
    is reassigned after each call.
    """

    # Class-level default only: every method assigns ``self.root`` before
    # relying on it, which creates a per-instance attribute, so instances
    # never share a tree.
    root = None

    def insert(self, key):
        """Add ``key`` to the set; a key that is already present is ignored."""
        left, right = SplayTree.split(self.root, key)
        # After the split, ``right`` is rooted at the smallest key >= key, so
        # the key is already present exactly when that root equals it.
        new_vertex = None
        if right is None or right.key != key:
            new_vertex = Vertex(key, key, None, None, None)
        self.root = SplayTree.merge(SplayTree.merge(left, new_vertex), right)

    def erase(self, key):
        """Remove ``key`` from the set; keys that are not present are ignored."""
        if self.search(key) is None:
            return
        # search() stopped at the matching vertex and splayed it to the root.
        removed = self.root
        left, right = removed.left, removed.right
        # Detach the children first. Otherwise their parent pointers still
        # reference the removed vertex, and merge()'s splay would rotate
        # through it before it gets discarded.
        removed.left = removed.right = None
        if left is not None:
            left.parent = None
        if right is not None:
            right.parent = None
        self.root = SplayTree.merge(left, right)

    def search(self, key):
        """Return ``key`` if it is in the set, else None.

        Note that 0 is a valid stored key, so callers must compare the result
        against None rather than test its truthiness.
        """
        result, self.root = SplayTree.find(self.root, key)
        if result is None or result.key != key:
            return None
        return result.key

    def sum(self, fr, to):
        """Return the sum of stored keys k with ``fr <= k <= to``.

        Cuts the tree into keys < fr, keys in [fr, to], and keys >= to + 1
        (i.e. > to for integer keys). It reads the middle part's cached
        subtree sum, then merges the three parts back together.
        """
        left, middle = SplayTree.split(self.root, fr)
        middle, right = SplayTree.split(middle, to + 1)
        ans = middle.sum if middle is not None else 0
        # merge() treats None as an empty tree, so an empty middle is fine.
        self.root = SplayTree.merge(SplayTree.merge(left, middle), right)
        return ans


if __name__ == "__main__":
    # Read all input at once and drop blank lines, so a stray empty line is
    # not mistaken for an operation (indexing an empty split would crash).
    rows = [ln.split() for ln in sys.stdin.read().splitlines()]
    rows = [row for row in rows if row]
    n = int(rows[0][0])

    MODULO = 1000000001  # M from the statement
    # x from the statement: result of the most recent sum query, 0 before any.
    last_sum_result = 0
    s = Set()
    # Collect output and write it once at the end instead of one print() per
    # query, which is noticeably slower for up to 100 000 queries.
    out = []

    for op, *args in rows[1:n + 1]:
        if op == "s":
            lo = (int(args[0]) + last_sum_result) % MODULO
            hi = (int(args[1]) + last_sum_result) % MODULO
            res = s.sum(lo, hi)
            out.append(str(res))
            last_sum_result = res % MODULO
        elif op in ("+", "-", "?"):
            value = (int(args[0]) + last_sum_result) % MODULO
            if op == "+":
                s.insert(value)
            elif op == "-":
                s.erase(value)
            else:
                out.append(
                    "Found" if s.search(value) is not None else "Not found"
                )
        # Unknown operations are skipped, as in the line-by-line original.

    if out:
        sys.stdout.write("\n".join(out) + "\n")