1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use super::*;

/// A description of a tree that can be materialised with a `Trees` context.
///
/// Lets rebalancing code spell out the shape of a result declaratively: a
/// value is either an already-built tree (wrapped in `T`) or a
/// `(left, key, right)` triple whose parts are converted recursively.
trait IntoTree<'a, Trees>
where
    Trees: MonadTrees<'a>,
{
    /// Consumes `self` and produces the corresponding tree inside `BTWrap`,
    /// using `trees` as the context for whatever work that requires.
    fn into_tree(self, trees: &Trees) -> BTWrap<'a, Trees, Trees::Tree>;
}

/// Marks a value as an already-built tree when describing a tree shape via
/// `IntoTree`.
struct T<Inner>(Inner);

impl<'a, Trees: MonadTrees<'a>> IntoTree<'a, Trees> for T<Trees::Tree> {
    fn into_tree(self, _trees: &Trees) -> BTWrap<'a, Trees, Trees::Tree> {
        Trees::pure(self.0)
    }
}

/// A `(left, key, right)` triple is built by materialising the left side, then
/// the right side, and finally joining both around `key` with
/// `join_key_balanced_tree` (which rebalances when needed).
impl<'a, Trees: BinaryTreesAvl<'a>, L: IntoTree<'a, Trees>, R: IntoTree<'a, Trees>>
    IntoTree<'a, Trees> for (L, Trees::Key, R)
{
    fn into_tree(self, trees: &Trees) -> BTWrap<'a, Trees, Trees::Tree> {
        let (left, key, right) = self;
        // Owned copy of the context so the `move` continuation below can own
        // it instead of borrowing `trees`.
        let ctx = trees.clone();
        let left_tree = left.into_tree(&ctx);
        let right_tree = right.into_tree(&ctx);
        Trees::T::bind2(left_tree, right_tree, move |tl, tr| {
            ctx.join_key_balanced_tree(tl, key, tr)
        })
    }
}

/// Convenience helpers layered on top of `BinaryTreesAvl`.
trait BinaryTreesAvlExt<'a>: BinaryTreesAvl<'a> {
    /// Builds a node from two tree descriptions and a separating key.
    ///
    /// `itl` is materialised first, then `itr`, both with `self` as context.
    /// The resulting trees are joined around `key` by `join_key_balanced`,
    /// which rotates if their heights differ by more than one.
    fn make_node(
        self,
        itl: impl IntoTree<'a, Self>,
        key: Self::Key,
        itr: impl IntoTree<'a, Self>,
    ) -> BTWrap<'a, Self, Self::Node> {
        let left = itl.into_tree(&self);
        let right = itr.into_tree(&self);
        // The continuation consumes the context: `join_key_balanced` takes
        // `self` by value.
        Self::T::bind2(left, right, move |tl, tr| self.join_key_balanced(tl, key, tr))
    }
}

/// Every `BinaryTreesAvl` implementation gets the extension helpers for free.
impl<'a, Trees> BinaryTreesAvlExt<'a> for Trees where Trees: BinaryTreesAvl<'a> {}

/// Relative shape of two sibling subtrees, as seen by AVL rebalancing.
enum Balancing {
    /// Heights differ by at most one; no rotation is needed.
    Balanced,
    /// The right subtree is at least two levels taller than the left.
    RightBiased,
    /// The left subtree is at least two levels taller than the right.
    LeftBiased,
}

/// Classifies a pair of subtree heights: `hl` for the left, `hr` for the right.
///
/// Saturating subtraction keeps this well-defined for every pair of `u64`
/// values. At most one of the two excesses below is non-zero.
fn balancing(hl: u64, hr: u64) -> Balancing {
    let right_excess = hr.saturating_sub(hl);
    let left_excess = hl.saturating_sub(hr);
    if right_excess > 1 {
        Balancing::RightBiased
    } else if left_excess > 1 {
        Balancing::LeftBiased
    } else {
        Balancing::Balanced
    }
}

/// AVL rebalancing on top of the generic binary-tree operations.
///
/// Every method has a default body written purely in terms of methods from
/// the supertraits, and results are wrapped in `BTWrap`. The blanket impl at
/// the end of this file provides the trait for every type implementing those
/// supertraits.
pub trait BinaryTreesAvl<'a>:
    TreesHeightError<'a> + BinaryTreesTreeOf<'a> + BinaryTreesTryJoin<'a>
{
    /// Resolves `tree` to its node, treating "no node" as an error.
    ///
    /// If `refer` yields no reference (presumably `tree` is a leaf), this
    /// reports `HeightError::LeafHeight` with the tree's height through
    /// `height_error` instead of resolving anything.
    fn assume_node(&self, tree: &Self::Tree) -> BTWrap<'a, Self, Self::Node> {
        match self.refer(tree) {
            Some(reference) => self.resolve(&reference),
            None => self.height_error(HeightError::LeafHeight(self.height(tree))),
        }
    }

    /// Resolves `tree` (see `assume_node`), splits the node, and continues
    /// with `f(self, left, key, right)`.
    ///
    /// Note the reordering: the tuple from `split` is bound as
    /// `(tl, tr, key)`, while `f` receives `(tl, key, tr)`. Within this
    /// method, `T` is the continuation's result type parameter and shadows
    /// the `T` tree-wrapper struct.
    fn assume_bind<T: Send>(
        self,
        tree: &Self::Tree,
        f: impl 'a + Send + FnOnce(Self, Self::Tree, Self::Key, Self::Tree) -> BTWrap<'a, Self, T>,
    ) -> BTWrap<'a, Self, T> {
        Self::bind(self.assume_node(tree), move |node| {
            let (tl, tr, key) = self.split(&node);
            f(self, tl, key, tr)
        })
    }

    /// Same as `join_key_balanced`, but passes the resulting node through
    /// `tree_bind` so the caller gets a `Self::Tree`.
    fn join_key_balanced_tree(
        &self,
        tl: Self::Tree,
        key: Self::Key,
        tr: Self::Tree,
    ) -> BTWrap<'a, Self, Self::Tree> {
        // `join_key_balanced` consumes `self`, so owned copies are made from
        // `&self` for it and for the `tree_bind` call.
        self.clone()
            .tree_bind(self.clone().join_key_balanced(tl, key, tr))
    }

    /// Joins `tl` and `tr` around `key` into one node, rotating when their
    /// heights differ by more than one.
    ///
    /// * Balanced (heights within one): delegates straight to `try_join`.
    /// * Right too tall: single left rotation around `tr`, or a double
    ///   (right-left) rotation when `tr`'s left child is the taller one.
    /// * Left too tall: the mirror image.
    ///
    /// Rebuilt nodes go through `make_node`. That calls `join_key_balanced`
    /// for the new root and `join_key_balanced_tree` for each nested
    /// `(left, key, right)` triple, so every rebuilt node is re-checked. If
    /// the subtree that must be split is not a node, `assume_bind` reports a
    /// `LeafHeight` error.
    ///
    /// NOTE(review): these are the classic one-step AVL fixups, which restore
    /// balance for a height difference of exactly two. Larger differences rely
    /// on the recursive re-checking above. Confirm callers stay within that
    /// range.
    fn join_key_balanced(
        self,
        tl: Self::Tree,
        key: Self::Key,
        tr: Self::Tree,
    ) -> BTWrap<'a, Self, Self::Node> {
        match balancing(self.height(&tl), self.height(&tr)) {
            // Heights within one of each other: no rotation needed.
            Balancing::Balanced => self.try_join(tl, key, tr),
            // Right side too tall: open `tr` as `(trl, kr, trr)`.
            Balancing::RightBiased => self.assume_bind(&tr, |ctx, trl, kr, trr| {
                if ctx.height(&trl) > ctx.height(&trr) {
                    // Inner grandchild `trl` is taller: double rotation.
                    // `trl = (trll, krl, trlr)` and `krl` becomes the root.
                    ctx.assume_bind(&trl, |ctx, trll, krl, trlr| {
                        ctx.make_node((T(tl), key, T(trll)), krl, (T(trlr), kr, T(trr)))
                    })
                } else {
                    // Single rotation: `kr` becomes the root, with
                    // `(tl, key, trl)` as its new left child.
                    ctx.make_node((T(tl), key, T(trl)), kr, T(trr))
                }
            }),
            // Left side too tall: open `tl` as `(tll, kl, tlr)` (mirror case).
            Balancing::LeftBiased => self.assume_bind(&tl, |ctx, tll, kl, tlr| {
                if ctx.height(&tll) < ctx.height(&tlr) {
                    // Inner grandchild `tlr` is taller: double rotation.
                    // `tlr = (tlrl, klr, tlrr)` and `klr` becomes the root.
                    ctx.assume_bind(&tlr, |ctx, tlrl, klr, tlrr| {
                        ctx.make_node((T(tll), kl, T(tlrl)), klr, (T(tlrr), key, T(tr)))
                    })
                } else {
                    // Single rotation: `kl` becomes the root, with
                    // `(tlr, key, tr)` as its new right child.
                    ctx.make_node(T(tll), kl, (T(tlr), key, T(tr)))
                }
            }),
        }
    }
}

/// Any type implementing `TreesHeightError`, `BinaryTreesTreeOf` and
/// `BinaryTreesTryJoin` gets `BinaryTreesAvl`, whose methods all have
/// default bodies.
impl<'a, BT> BinaryTreesAvl<'a> for BT
where
    BT: TreesHeightError<'a> + BinaryTreesTreeOf<'a> + BinaryTreesTryJoin<'a>,
{
}