1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
//! Casting between trait objects.
//!
//! This is a simple implementation of casting between trait objects. The idea
//! is from the [intertrait](https://crates.io/crates/intertrait) crate.
//!
//! The intertrait crate uses [`linkme`](https://crates.io/crates/linkme) to
//! register the casters at link time. Here, instead, all the casters are
//! managed by a separate storage (a member of the IR context), and only
//! casting to shared and mutable references is supported.
//!
//! All the casters can be registered when the operation/type/dialect is
//! registered into the context. For verifiers, the casters will be registered
//! automatically when the operation/type is derived. For interfaces, those
//! declared when deriving will be registered automatically, and the others can
//! be registered by calling the `register_caster` macro in the `register`
//! function of the dialect.
//!
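//! For the mechanism of this implementation, see [`Caster`] and the test
//! cases.
//!
//! Note that a boxed trait object must be dereferenced (e.g. with `as_ref` or
//! `as_mut`) before casting; otherwise the type id of the box itself is used
//! for the lookup and no caster is found.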
use std::{
    any::{Any, TypeId},
    collections::HashMap,
};

use downcast_rs::Downcast;

/// The caster storage.
///
/// Maps a `(concrete type id, caster type id)` pair to a type-erased caster.
/// The first id is `TypeId::of::<S>()` for the concrete source type `S`, the
/// second is `TypeId::of::<Caster<dyn Target>>()`, and the value is the boxed
/// `Caster<dyn Target>` registered for that pair.
#[derive(Debug, Default)]
pub struct CasterStorage(HashMap<(TypeId, TypeId), Box<dyn Any>>);

/// A caster for trait-to-trait casting.
///
/// Given a trait object and a target type `dyn Target`, the object is first
/// upcast to `dyn Any`. The caster is then fetched from [`CasterStorage`]
/// using two keys: the concrete type id read from the `Any` object, and the
/// type id of `Caster<dyn Target>`. Finally, the stored function pointers
/// perform the cast.
///
/// The caster functions can be implemented simply by downcasting the `Any`
/// object to the concrete type. The function-pointer return type then lets
/// the compiler coerce the concrete reference to the target trait object.
/// The functions are only called with an `Any` whose concrete type id matches
/// the key the caster was looked up with, so the downcast inside them
/// succeeds as long as the caster was registered for that concrete type.
///
/// # Example
///
/// A caster can be implemented like this:
///
/// ```rust,ignore
/// let caster = Caster::<dyn Target>::new(
///     |any| any.downcast_ref::<Concrete>().unwrap(),
///     |any| any.downcast_mut::<Concrete>().unwrap(),
/// );
/// ```
pub struct Caster<T: ?Sized + 'static> {
    /// Casting from any to the target trait.
    cast_ref: fn(&dyn Any) -> &T,
    /// Casting from any to the target trait mutably.
    cast_mut: fn(&mut dyn Any) -> &mut T,
}

impl<T: ?Sized + 'static> Caster<T> {
    /// Create a new caster.
    pub fn new(cast_ref: fn(&dyn Any) -> &T, cast_mut: fn(&mut dyn Any) -> &mut T) -> Self {
        Self { cast_ref, cast_mut }
    }
}

impl CasterStorage {
    /// Register a caster into the storage.
    ///
    /// `S` is the concrete source type and `T` is the cast target (usually a
    /// `dyn Trait`). The caster is stored under the key
    /// `(TypeId::of::<S>(), TypeId::of::<Caster<T>>())`; registering another
    /// caster for the same pair silently replaces the previous one.
    ///
    /// Lookups are keyed by the type id of the concrete value behind a
    /// `dyn Any`, so `S` should be the concrete type rather than a trait
    /// object.
    pub fn register<S: ?Sized + 'static, T: ?Sized + 'static>(&mut self, caster: Caster<T>) {
        let concrete_id = TypeId::of::<S>();
        let caster_id = TypeId::of::<Caster<T>>();
        self.0.insert((concrete_id, caster_id), Box::new(caster));
    }

    /// Look up the caster from the concrete type identified by `id` to `T`.
    ///
    /// `id` must be the type id of the underlying concrete type, so every
    /// `dyn Source` object should be upcast to `dyn Any` first to obtain the
    /// id of the concrete type. Otherwise, the type id of `dyn Source` itself
    /// is used for the lookup, which leads to `None`.
    fn lookup<T: ?Sized + 'static>(&self, id: TypeId) -> Option<&Caster<T>> {
        let caster_id = TypeId::of::<Caster<T>>();
        self.0.get(&(id, caster_id)).map(|caster| {
            // `register` is the only writer and always stores a `Caster<T>`
            // under `TypeId::of::<Caster<T>>()`, so this downcast cannot fail.
            caster
                .downcast_ref::<Caster<T>>()
                .expect("an entry keyed by `Caster<T>`'s type id must hold a `Caster<T>`")
        })
    }
}

/// Casting a value to a shared reference of another type (typically a
/// `dyn Trait`) using the casters held in a [`CasterStorage`].
pub trait CastRef {
    /// Returns `true` if `caster_storage` holds a caster from the concrete
    /// type of `self` to `T`.
    fn impls<T: ?Sized + 'static>(&self, caster_storage: &CasterStorage) -> bool;

    /// Casts `self` to `&T`, or returns `None` if `caster_storage` holds no
    /// caster from the concrete type of `self` to `T`.
    fn cast_ref<T: ?Sized + 'static>(&self, caster_storage: &CasterStorage) -> Option<&T>;
}

/// Casting a value to a mutable reference of another type (typically a
/// `dyn Trait`) using the casters held in a [`CasterStorage`].
pub trait CastMut {
    /// Casts `self` to `&mut T`, or returns `None` if `caster_storage` holds
    /// no caster from the concrete type of `self` to `T`.
    fn cast_mut<T: ?Sized + 'static>(&mut self, caster_storage: &CasterStorage) -> Option<&mut T>;
}

/// Blanket implementation for every type that can be upcast to `dyn Any`.
///
/// The lookup key is the type id of the value returned by `as_any()`, i.e.
/// the concrete type of `self`. When called on a `Box<dyn Source>` directly,
/// the receiver is the box itself, so the box's own type id is used; call
/// `as_ref()` on the box first.
impl<S> CastRef for S
where
    S: Downcast + ?Sized,
{
    fn impls<T: ?Sized + 'static>(&self, caster_storage: &CasterStorage) -> bool {
        caster_storage
            .lookup::<T>(Any::type_id(self.as_any()))
            .is_some()
    }

    fn cast_ref<T: ?Sized + 'static>(&self, caster_storage: &CasterStorage) -> Option<&T> {
        let erased = self.as_any();
        // Dispatch through the `dyn Any` vtable to get the concrete type id.
        let caster = caster_storage.lookup::<T>(Any::type_id(erased))?;
        Some((caster.cast_ref)(erased))
    }
}

/// Blanket implementation for every type that can be upcast to `dyn Any`.
///
/// Works like the `CastRef` blanket implementation, but hands out an
/// exclusive reference obtained through `as_any_mut()`.
impl<S> CastMut for S
where
    S: Downcast + ?Sized,
{
    fn cast_mut<T: ?Sized + 'static>(&mut self, caster_storage: &CasterStorage) -> Option<&mut T> {
        let erased = self.as_any_mut();
        // Reborrow as `&dyn Any` so the id is read through the `dyn Any`
        // vtable (the concrete type), not taken from the reference type.
        let concrete_id = Any::type_id(&*erased);
        let caster = caster_storage.lookup::<T>(concrete_id)?;
        Some((caster.cast_mut)(erased))
    }
}

#[cfg(test)]
mod tests {
    use downcast_rs::Downcast;

    use super::{CastMut, CastRef, Caster, CasterStorage};

    struct ConcreteStruct {
        value: i32,
    }

    trait TraitFrom: Downcast {
        fn identity(&self) -> i32;
    }

    trait TraitTo {
        fn double(&self) -> i32;
        fn set_value(&mut self, value: i32);
    }

    impl TraitFrom for ConcreteStruct {
        fn identity(&self) -> i32 { self.value }
    }

    impl TraitTo for ConcreteStruct {
        fn double(&self) -> i32 { self.value * 2 }

        fn set_value(&mut self, value: i32) { self.value = value; }
    }

    /// Builds a storage holding a single `ConcreteStruct -> dyn TraitTo` caster.
    fn storage_with_trait_to() -> CasterStorage {
        let mut casters = CasterStorage::default();
        // The closures' return type comes from the function-pointer field
        // type, so the concrete references coerce to `dyn TraitTo` implicitly.
        casters.register::<ConcreteStruct, dyn TraitTo>(Caster {
            cast_ref: |any| any.downcast_ref::<ConcreteStruct>().unwrap(),
            cast_mut: |any| any.downcast_mut::<ConcreteStruct>().unwrap(),
        });
        casters
    }

    #[test]
    fn cast_ref_and_cast_mut_reach_the_same_object() {
        let casters = storage_with_trait_to();

        let mut from_obj: Box<dyn TraitFrom> = Box::new(ConcreteStruct { value: 5 });
        assert_eq!(from_obj.identity(), 5);

        // The box must be dereferenced with `as_ref`; otherwise the lookup
        // uses the type id of `Box<dyn TraitFrom>` itself.
        let to = from_obj.as_ref().cast_ref::<dyn TraitTo>(&casters).unwrap();
        assert_eq!(to.double(), 10);

        // Likewise, `as_mut` is needed for the mutable cast.
        let to = from_obj.as_mut().cast_mut::<dyn TraitTo>(&casters).unwrap();
        to.set_value(114514);
        assert_eq!(to.double(), 229028);

        assert_eq!(from_obj.identity(), 114514);
    }

    #[test]
    fn impls_reports_registered_targets_only() {
        let casters = storage_with_trait_to();
        let from_obj: Box<dyn TraitFrom> = Box::new(ConcreteStruct { value: 1 });

        assert!(from_obj.as_ref().impls::<dyn TraitTo>(&casters));
        assert!(!from_obj.as_ref().impls::<dyn TraitFrom>(&casters));
    }

    #[test]
    fn unregistered_target_yields_none() {
        let casters = CasterStorage::default();
        let mut from_obj: Box<dyn TraitFrom> = Box::new(ConcreteStruct { value: 1 });

        assert!(from_obj.as_ref().cast_ref::<dyn TraitTo>(&casters).is_none());
        assert!(from_obj.as_mut().cast_mut::<dyn TraitTo>(&casters).is_none());
    }

    #[test]
    fn casting_the_box_itself_misses_the_caster() {
        let casters = storage_with_trait_to();
        let from_obj: Box<dyn TraitFrom> = Box::new(ConcreteStruct { value: 1 });

        // Without `as_ref`, the receiver is the `Box`, whose type id was
        // never registered.
        assert!(!from_obj.impls::<dyn TraitTo>(&casters));
    }
}