extendr_api/wrapper/
function.rs

1use super::*;
2use extendr_ffi::{get_closure_body, get_closure_env, get_closure_formals, Rf_lcons};
3/// Wrapper for creating functions (CLOSSXP).
4/// ```
5/// use extendr_api::prelude::*;
6/// test! {
7///     // Closures are functions.
8///     let expr = R!("function(a = 1, b) {c <- a + b}")?;
9///     let func = expr.as_function().unwrap();
10///
11///     let expected_formals = Pairlist::from_pairs(vec![("a", r!(1.0)), ("b", missing_arg().into())]);
12///     let expected_body = lang!(
13///         "{", lang!("<-", sym!(c), lang!("+", sym!(a), sym!(b))));
14///     assert_eq!(func.formals().unwrap(), expected_formals);
15///     assert_eq!(func.body().unwrap(), expected_body);
16///     assert_eq!(func.environment().unwrap(), global_env());
17///
18///     // Primitives can also be functions.
19///     let expr = R!("`~`")?;
20///     let func = expr.as_function().unwrap();
21///     assert_eq!(func.formals(), None);
22///     assert_eq!(func.body(), None);
23///     assert_eq!(func.environment(), None);
24/// }
25/// ```
26#[derive(PartialEq, Clone)]
27pub struct Function {
28    pub(crate) robj: Robj,
29}
30
31impl Function {
32    #[cfg(feature = "non-api")]
33    /// Make a function from parts.
34    /// ```
35    /// use extendr_api::prelude::*;
36    /// test! {
37    ///     let formals = pairlist!(a=NULL);
38    ///     let body = lang!("+", sym!(a), r!(1)).try_into()?;
39    ///     let env = global_env();
40    ///     let f = r!(Function::from_parts(formals, body, env )?);
41    ///     assert_eq!(f.call(pairlist!(a=1))?, r!(2));
42    /// }
43    /// ```
44    pub fn from_parts(formals: Pairlist, body: Language, env: Environment) -> Result<Self> {
45        single_threaded(|| unsafe {
46            let sexp = extendr_ffi::Rf_allocSExp(SEXPTYPE::CLOSXP);
47            let robj = Robj::from_sexp(sexp);
48            extendr_ffi::SET_FORMALS(sexp, formals.get());
49            extendr_ffi::SET_BODY(sexp, body.get());
50            extendr_ffi::SET_CLOENV(sexp, env.get());
51            Ok(Function { robj })
52        })
53    }
54
55    /// Do the equivalent of x(a, b, c)
56    /// ```
57    /// use extendr_api::prelude::*;
58    /// test! {
59    ///     let function = R!("function(a, b) a + b").unwrap().as_function().unwrap();
60    ///     assert_eq!(function.call(pairlist!(a=1, b=2)).unwrap(), r!(3));
61    /// }
62    /// ```
63    pub fn call(&self, args: Pairlist) -> Result<Robj> {
64        single_threaded(|| unsafe {
65            let call = Robj::from_sexp(Rf_lcons(self.get(), args.get()));
66            call.eval()
67        })
68    }
69
70    /// Get the formal arguments of the function or None if it is a primitive.
71    pub fn formals(&self) -> Option<Pairlist> {
72        unsafe {
73            if self.rtype() == Rtype::Function {
74                let sexp = self.robj.get();
75                Some(
76                    Robj::from_sexp(get_closure_formals(sexp))
77                        .try_into()
78                        .unwrap(),
79                )
80            } else {
81                None
82            }
83        }
84    }
85
86    /// Get the body of the function or None if it is a primitive.
87    pub fn body(&self) -> Option<Robj> {
88        unsafe {
89            if self.rtype() == Rtype::Function {
90                let sexp = self.robj.get();
91                Some(Robj::from_sexp(get_closure_body(sexp)))
92            } else {
93                None
94            }
95        }
96    }
97
98    /// Get the environment of the function or None if it is a primitive.
99    pub fn environment(&self) -> Option<Environment> {
100        unsafe {
101            if self.rtype() == Rtype::Function {
102                let sexp = self.robj.get();
103                Some(
104                    Robj::from_sexp(get_closure_env(sexp))
105                        .try_into()
106                        .expect("Should be an environment"),
107                )
108            } else {
109                None
110            }
111        }
112    }
113}
114
115impl std::fmt::Debug for Function {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        write!(f, "{}", self.deparse().unwrap())
118    }
119}