Maximum Sum BST

p
from __future__ import annotations

import sys
from dataclasses import dataclass

# Finite stand-ins for "minus infinity" / "plus infinity", mirrored around zero
# (INT_MIN == -INT_MAX). Python integers are unbounded, so these are only
# extreme for values that lie strictly between them.
INT_MIN = 1 - sys.maxsize
INT_MAX = sys.maxsize - 1


@dataclass
class TreeNode:
    """
    A binary tree node: an integer key plus optional left and right children.

    All three fields have defaults, so ``TreeNode()`` builds a childless node
    with key 0 and children can be attached afterwards by assignment.
    """

    # Key stored at this node.
    val: int = 0
    # Left child, or None when there is no left subtree.
    left: TreeNode | None = None
    # Right child, or None when there is no right subtree.
    right: TreeNode | None = None


def max_sum_bst(root: TreeNode | None) -> int:
    """
    Return the largest sum of keys over all subtrees that are valid Binary
    Search Trees (BSTs).

    A subtree is a node together with all of its descendants. It is a BST when
    both child subtrees are BSTs, every key on the left is strictly smaller
    than the node's key and every key on the right is strictly larger (a
    duplicate key therefore disqualifies the subtree). The empty subtree is a
    BST with sum 0, so the result is never negative.

    The tree is processed bottom-up with an explicit stack instead of
    recursion, so very deep (e.g. completely skewed) trees do not hit the
    interpreter's recursion limit. The bounds of an empty subtree are
    represented by +/- infinity, so keys of any magnitude are handled.

    :param root: root node of the binary tree, or ``None`` for an empty tree.
    :return: the maximum key sum of any BST subtree, or 0 when every BST
        subtree has a negative sum or the tree is empty.

    >>> t1 = TreeNode(4)
    >>> t1.left = TreeNode(3)
    >>> t1.left.left = TreeNode(1)
    >>> t1.left.right = TreeNode(2)
    >>> print(max_sum_bst(t1))
    2
    >>> t2 = TreeNode(-4)
    >>> t2.left = TreeNode(-2)
    >>> t2.right = TreeNode(-5)
    >>> print(max_sum_bst(t2))
    0
    >>> t3 = TreeNode(1)
    >>> t3.left = TreeNode(4)
    >>> t3.left.left = TreeNode(2)
    >>> t3.left.right = TreeNode(4)
    >>> t3.right = TreeNode(3)
    >>> t3.right.left = TreeNode(2)
    >>> t3.right.right = TreeNode(5)
    >>> t3.right.right.left = TreeNode(4)
    >>> t3.right.right.right = TreeNode(6)
    >>> print(max_sum_bst(t3))
    20
    >>> max_sum_bst(None)
    0
    >>> max_sum_bst(TreeNode(2**70)) == 2**70
    True
    """
    # Subtree summary: (is_bst, smallest key, largest key, key sum).
    # An empty subtree uses min=+inf / max=-inf so that any key fits next to
    # it; int-vs-float comparisons in Python are exact, so this is safe even
    # for integers far beyond the float range of exact representation.
    empty_summary = (True, float("inf"), float("-inf"), 0)
    # The bounds and sum of a non-BST summary are never read.
    not_bst_summary = (False, 0, 0, 0)

    best = 0
    # Work list of (node, children_already_scheduled) pairs.
    pending: list[tuple[TreeNode | None, bool]] = [(root, False)]
    # Summaries of finished subtrees, in post-order.
    summaries: list[tuple[bool, float, float, int]] = []

    while pending:
        node, children_done = pending.pop()

        if node is None:
            summaries.append(empty_summary)
            continue

        if not children_done:
            # Revisit this node once both children have been summarised.
            # Left is pushed last so it is processed first; its summary then
            # sits below the right one on ``summaries``.
            pending.append((node, True))
            pending.append((node.right, False))
            pending.append((node.left, False))
            continue

        right_ok, right_min, right_max, right_sum = summaries.pop()
        left_ok, left_min, left_max, left_sum = summaries.pop()

        if left_ok and right_ok and left_max < node.val < right_min:
            total = left_sum + node.val + right_sum
            best = max(best, total)
            summaries.append(
                (True, min(left_min, node.val), max(right_max, node.val), total)
            )
        else:
            summaries.append(not_bst_summary)

    return best


if __name__ == "__main__":
    # When executed as a script, run every doctest embedded in this module.
    from doctest import testmod

    testmod()