resources

implementation

point update

from math import gcd, inf, lcm
 
class SegmentTree:
    def __init__(self, nums: List[int], mode="sum", target=None):
        self.mode = mode
        self.target = target
        self.n = len(nums)
        self.nums = nums
        self.tree = [None] * (2 * self.n)
        self.build(1, 1, self.n)
 
    def __left_child(self, idx, tree_left, tree_right):
        # simple indexing: O(4*n)
        # return idx * 2
 
        # advanced indexing: O(2*n)
        # order the nodes in in-order traversal order
        # (root <nodes in left tree> <nodes in right tree>)
        # left node is idx + 1, and there are (mid - left + 1) leaves in the left tree
        # hence there are 2 * (mid - left + 1) - 1 nodes in the left tree
        # thus right node is at `idx + 2 * (mid - left + 1)`
        return idx + 1
 
    def __right_child(self, idx, tree_left, tree_right):
        # simple indexing:
        # return idx * 2 + 1
 
        # advanced indexing:
        # order the nodes in in-order traversal order
        # (root <nodes in left tree> <nodes in right tree>)
        # left node is idx + 1, and there are (mid - left + 1) leaves in the left tree
        # hence there are 2 * (mid - left + 1) - 1 nodes in the left tree
        # thus right node is at `idx + 2 * (mid - left + 1)`
        mid = (tree_left + tree_right) // 2
        return idx + 2 * (mid - tree_left + 1)
 
    def __combine(self, left, right):
        left_val, left_count = left
        right_val, right_count = right
        if self.mode == "sum":
            return (left_val + right_val, left_count + right_count)
        elif self.mode == "prod":
            return (left_val * right_val, left_count + right_count)
        elif self.mode == "gcd":
            return (gcd(left_val, right_val), left_count + right_count)
        elif self.mode == "lcm":
            return (lcm(left_val, right_val), left_count + right_count)
        elif self.mode == "max":
            if left_val > right_val:
                # if target is not specified, count the max value instead
                if self.target is None:
                    return left
                else:
                    return (left_val, left_count + right_count)
            elif left_val < right_val:
                # if target is not specified, count the max value instead
                if self.target is None:
                    return right
                else:
                    return (right_val, left_count + right_count)
            else:  # left_val == right_val
                return (left_val, left_count + right_count)
        elif self.mode == "min":
            if left_val < right_val:
                # if target is not specified, count the min value instead
                if self.target is None:
                    return left
                else:
                    return (left_val, left_count + right_count)
            elif left_val > right_val:
                # if target is not specified, count the min value instead
                if self.target is None:
                    return right
                else:
                    return (right_val, left_count + right_count)
            else:  # left_val == right_val
                return (left_val, left_count + right_count)
 
    def build(self, tree_idx, tree_left, tree_right):
        # leaf node
        if tree_left == tree_right:
            if self.target is None:
                self.tree[tree_idx] = (self.nums[tree_left - 1], 1)
            else:
                self.tree[tree_idx] = (
                    self.nums[tree_left - 1],
                    int(self.nums[tree_left - 1] == self.target),
                )
            return
 
        mid = (tree_left + tree_right) // 2
 
        # get indices of left and right children
        left_idx = self.__left_child(tree_idx, tree_left, tree_right)
        right_idx = self.__right_child(tree_idx, tree_left, tree_right)
 
        # recursively build the tree
        self.build(left_idx, tree_left, mid)
        self.build(right_idx, mid + 1, tree_right)
 
        # combine the left and right children
        self.tree[tree_idx] = self.__combine(self.tree[left_idx], self.tree[right_idx])
 
    def __query(self, tree_idx, tree_left, tree_right, query_left, query_right):
        if query_left > query_right:
            if self.mode == "sum":
                return (0, 0)
            if self.mode == "prod":
                return (1, 0)
            if self.mode == "gcd":
                return (0, 0)
            if self.mode == "lcm":
                return (1, 0)
            elif self.mode == "max":
                return (-inf, 0)
            elif self.mode == "min":
                return (inf, 0)
 
        # span is perfectly covered by a node
        if tree_left == query_left and tree_right == query_right:
            return self.tree[tree_idx]
 
        mid = (tree_left + tree_right) // 2
 
        # get indices of left and right children
        left_idx = self.__left_child(tree_idx, tree_left, tree_right)
        right_idx = self.__right_child(tree_idx, tree_left, tree_right)
 
        # recursively query the left and right subtrees
        tree_left = self.__query(
            left_idx, tree_left, mid, query_left, min(mid, query_right)
        )
        tree_right = self.__query(
            right_idx, mid + 1, tree_right, max(mid + 1, query_left), query_right
        )
 
        # combine the left and right subtrees
        return self.__combine(tree_left, tree_right)
 
    def query(self, left, right):
        return self.__query(1, 1, self.n, left + 1, right + 1)
 
    def __update(self, pos, val, tree_idx, tree_left, tree_right):
        # leaf node
        if tree_left == tree_right:
            if self.target is None:
                self.tree[tree_idx] = (val, 1)
            else:
                self.tree[tree_idx] = (val, int(val == self.target))
            return
 
        mid = (tree_left + tree_right) // 2
 
        # get indices of left and right children
        left_idx = self.__left_child(tree_idx, tree_left, tree_right)
        right_idx = self.__right_child(tree_idx, tree_left, tree_right)
 
        # recursively update the left or right subtrees
        if pos <= mid:
            self.__update(pos, val, left_idx, tree_left, mid)
        else:
            self.__update(pos, val, right_idx, mid + 1, tree_right)
 
        self.tree[tree_idx] = self.__combine(self.tree[left_idx], self.tree[right_idx])
 
    def update(self, pos, val):
        self.nums[pos] = val
        self.__update(pos + 1, val, 1, 1, self.n)
 
    def __find_kth(self, k, tree_idx, tree_left, tree_right):
        if tree_left == tree_right:
            return tree_left - 1
 
        mid = (tree_left + tree_right) // 2
 
        left_idx = self.__left_child(tree_idx, tree_left, tree_right)
        right_idx = self.__right_child(tree_idx, tree_left, tree_right)
 
        _, left_count = self.tree[left_idx]
 
        if k <= left_count:
            return self.__find_kth(k, left_idx, tree_left, mid)
        else:
            return self.__find_kth(k - left_count, right_idx, mid + 1, tree_right)
 
    def find_kth(self, k):
        if k <= 0:
            return -1
        _, total_count = self.tree[1]
        if total_count < k:
            return -1
 
        return self.__find_kth(k, 1, 1, self.n)
 
 
segtree = SegmentTree([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], mode="sum", target=10)
print(segtree.query(1, 5))
segtree.update(5, 10)
print(segtree.query(1, 5))
 
segtree = SegmentTree([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], mode="max")
print(segtree.query(1, 5))
segtree.update(5, 10)
print(segtree.query(1, 5))
 
segtree = SegmentTree([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], mode="min")
print(segtree.query(1, 5))
segtree.update(5, 1)
print(segtree.query(1, 5))
 
segtree = SegmentTree([1, 1, 3, 4, 1, 6, 7, 1, 9, 10], mode="prod", target=1)
print(segtree.query(1, 5))
segtree.update(5, 1)
print(segtree.query(1, 5))
 
segtree = SegmentTree([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], mode="lcm", target=10)
print(segtree.query(1, 5))
segtree.update(5, 1)
print(segtree.query(1, 5))
 
segtree = SegmentTree([1, 1, 3, 4, 1, 6, 7, 1, 9, 10], mode="prod", target=1)
print(segtree.query(1, 5))
segtree.update(5, 1)
print(segtree.query(1, 5))