How we wrote xtensor 6/N: Operator Overloading
xtensor is a comprehensive framework for processing N-dimensional arrays, including an extensible expression system, lazy evaluation, and many other features that cannot be covered in a single article. In this post, we focus on operator overloading.
In the previous article, we introduced expression templates and developed the skeleton of the xfunction class, meant to represent a node of the abstract syntax tree of a tensor expression. We left aside how xfunction is passed the types of its operands, as well as the way xfunction objects are instantiated. This article returns to these key points of xtensor's expression system.
Closure types
Since xfunction is meant to represent a node of an abstract syntax tree, eventually it will be instantiated in the overload of an arithmetic operator or a mathematical function:
template <class E1, class E2>
inline auto operator+(const E1& op1,
                      const E2& op2)
{
    using function_type = xfunction<std::plus<>,
                                    const E1&,
                                    const E2&>;
    return function_type(std::plus<>(), op1, op2);
}
where E1 and E2 can be arbitrarily complex expressions (for now, xfunction parametrized by other types, or xarray). Therefore, the template arguments of xfunction are the parameter types of the operator that instantiates it. Simple, isn’t it?
Unfortunately, things are a bit more complicated. Assume you want to pass a temporary returned from a function to the operator:
xarray<double> compute_array();
// ....
auto f = compute_array() + compute_array();
The temporaries returned from compute_array are bound to const references (the parameters of operator+), which is fine since these are stack-based references. They are then “transferred” to the xfunction object and stored as data member references, which are NOT stack-based references. Therefore the rule that extends the lifetime of a temporary bound to a constant reference does not apply, resulting in dangling references.
We need to detect at the call site whether the parameter is an lvalue or a temporary. Fortunately, C++11 introduced universal references, which make it possible to capture the value category of a function argument and achieve perfect forwarding. Diving into the details of universal references and perfect forwarding is beyond the scope of this article; you can find more about them on the isocpp blog.
However, we cannot use universal references to store the arguments in the xfunction object:
template <class F, class... CT>
class xfunction
{
    //...
    std::tuple<CT&&...> m_e;
};
Indeed, the key to universal references is the deducing context. When the tuple type is instantiated, the compiler already knows the types CT, so CT&& refers to classical rvalue references.
Therefore, we need an intermediate structure that specifies how to store the operands in the xfunction object: if the argument is an lvalue reference, it should be stored as a constant reference; otherwise, the argument is either an rvalue reference or a value, and it should be stored as a value. This (the way an argument is stored in the xfunction object) is called the closure type. Two versions should be provided: one that keeps the constness of the argument, and one that always “constifies” the reference to avoid side effects. Thanks to the type traits metafunctions introduced in C++11 (and their _t shortcuts from C++14), the implementation is straightforward:
template <class S>
struct closure_type
{
    using underlying_type = std::conditional_t<
        std::is_const<std::remove_reference_t<S>>::value,
        std::add_const_t<std::decay_t<S>>,
        std::decay_t<S>>;

    using type = std::conditional_t<
        std::is_lvalue_reference<S>::value,
        underlying_type&,
        underlying_type>;
};

template <class S>
using closure_type_t = typename closure_type<S>::type;

template <class S>
struct const_closure_type
{
    using underlying_type = std::decay_t<S>;

    using type = std::conditional_t<
        std::is_lvalue_reference<S>::value,
        std::add_const_t<underlying_type>&,
        underlying_type>;
};

template <class S>
using const_closure_type_t = typename const_closure_type<S>::type;
We now have all that we need to implement the addition operator:
template <class E1, class E2>
inline auto operator+(E1&& e1, E2&& e2)
{
    using function_type = xfunction<std::plus<>,
                                    const_closure_type_t<E1>,
                                    const_closure_type_t<E2>>;
    return function_type(std::plus<>(),
                         std::forward<E1>(e1),
                         std::forward<E2>(e2));
}
Operator overloads
Repeating the type of xfunction for each operator overload (and later for mathematical functions) is going to be cumbersome. Since the only difference between the operators is the type of the functor, we can refactor the instantiation of the xfunction object to avoid code duplication:
template <class F, class... CT>
struct xfunction_type
{
    using type = xfunction<F, const_closure_type_t<CT>...>;
};

template <class F, class... CT>
inline auto make_xfunction(CT&&... e) noexcept
{
    using function_type = typename xfunction_type<F, CT...>::type;
    return function_type(F(), std::forward<CT>(e)...);
}

template <class E1, class E2>
inline auto operator+(E1&& e1, E2&& e2)
{
    return make_xfunction<std::plus<>>(std::forward<E1>(e1),
                                       std::forward<E2>(e2));
}
That’s much better! The implementation of the operator overload now simply forwards its arguments to the make_xfunction generator. Moreover, if we later decide to change how the xfunction object is instantiated, we only need to modify the code of make_xfunction instead of going through all the operator overloads to make the required changes.
Implementing the remaining arithmetic operators is just as simple; only the functor type differs (std::minus, std::multiplies and std::divides). However, the STL does not provide functors for mathematical functions (std::exp, std::log, …), so we have to implement our own. Besides, we will use the functors to implement vectorized operations based on xsimd in the near future, so it might be worth implementing our own arithmetic functors as well.
Arithmetic functors
Let’s start with a simple addition functor, really close to std::plus from the STL:
namespace xt
{
    struct plus
    {
        template <class T1, class T2>
        constexpr auto operator()(const T1& t1, const T2& t2) const
        {
            return t1 + t2;
        }
    };
}
The main difference with std::plus is that our functor is not a class template; its call operator is templated instead. Notice that this is the direction taken by the C++ standard too: the template parameter of std::plus is defaulted to void in C++14, and the inner typedefs are deprecated in C++17 and removed in C++20 (as you can see here).
The reason for this is mixed arithmetic. With the historical implementation in the STL, both operands must have the same type. This means that type conversion occurs before the call if that is possible; otherwise the type substitution fails, resulting in a compilation error.
However, there are cases where the type substitution fails but the operation is actually valid. Consider the following numerical class:
class special_double
{
public:

    explicit special_double(const double& d);
    special_double& operator+=(const double& d);
};

special_double operator+(const special_double& lhs,
                         const special_double& rhs);
special_double operator+(const special_double& lhs,
                         const double& rhs);
special_double operator+(const double& lhs,
                         const special_double& rhs);
Writing d + sd where d has type double and sd has type special_double is legit since we provide an overload of operator+ that accepts these types. However, special_double cannot be implicitly converted from double since its constructor is explicit. Therefore, any call to std::plus<special_double>::operator() where one of the arguments has type double would fail to compile.
Implementing functors for all operators is a repetitive and cumbersome task. Most of this code is boilerplate, so we should try to get it generated by macros:
#define UNARY_OPERATOR_FUNCTOR(NAME, OP)              \
    struct NAME                                       \
    {                                                 \
        template <class T>                            \
        constexpr auto operator()(const T& t) const   \
        {                                             \
            return OP t;                              \
        }                                             \
    }

#define BINARY_OPERATOR_FUNCTOR(NAME, OP)             \
    struct NAME                                       \
    {                                                 \
        template <class T1, class T2>                 \
        constexpr auto operator()(const T1& t1,       \
                                  const T2& t2) const \
        {                                             \
            return t1 OP t2;                          \
        }                                             \
    }
We can now easily generate all the required functors in the xt namespace:
namespace xt
{
    UNARY_OPERATOR_FUNCTOR(identity, +);
    UNARY_OPERATOR_FUNCTOR(negate, -);
    BINARY_OPERATOR_FUNCTOR(plus, +);
    BINARY_OPERATOR_FUNCTOR(minus, -);
    BINARY_OPERATOR_FUNCTOR(multiplies, *);
    BINARY_OPERATOR_FUNCTOR(divides, /);
    BINARY_OPERATOR_FUNCTOR(modulus, %);
    BINARY_OPERATOR_FUNCTOR(logical_or, ||);
    BINARY_OPERATOR_FUNCTOR(logical_and, &&);
    UNARY_OPERATOR_FUNCTOR(logical_not, !);
    BINARY_OPERATOR_FUNCTOR(bitwise_or, |);
    BINARY_OPERATOR_FUNCTOR(bitwise_and, &);
    BINARY_OPERATOR_FUNCTOR(bitwise_xor, ^);
    UNARY_OPERATOR_FUNCTOR(bitwise_not, ~);
    BINARY_OPERATOR_FUNCTOR(left_shift, <<);
    BINARY_OPERATOR_FUNCTOR(right_shift, >>);
    BINARY_OPERATOR_FUNCTOR(less, <);
    BINARY_OPERATOR_FUNCTOR(less_equal, <=);
    BINARY_OPERATOR_FUNCTOR(greater, >);
    BINARY_OPERATOR_FUNCTOR(greater_equal, >=);
    BINARY_OPERATOR_FUNCTOR(equal_to, ==);
    BINARY_OPERATOR_FUNCTOR(not_equal_to, !=);
}
Mathematical functors
Things are a bit more complicated for mathematical functions. A first idea would be to implement functors similar to the previous ones, forwarding their calls to the mathematical functions of the STL:
#define UNARY_MATH_FUNCTOR(NAME)                      \
    struct NAME##_fun                                 \
    {                                                 \
        template <class T>                            \
        constexpr auto operator()(const T& t) const   \
        {                                             \
            return std::NAME(t);                      \
        }                                             \
    }

UNARY_MATH_FUNCTOR(exp);
...
We can then implement the xt::exp function that accepts N-dimensional expressions:
template <class E>
inline auto exp(E&& e)
{
    return make_xfunction<xt::exp_fun>(std::forward<E>(e));
}
Calling exp on an xarray object now returns an xfunction parametrized by exp_fun. When one tries to access an element, the functor is called with the corresponding value of the underlying array, and forwards the call to std::exp:
xt::xarray<double> a = { .... };
// f has type xfunction<exp_fun, const xarray<double>&>
auto f = exp(a);
// internally calls exp_fun::operator()(a(1))
double res = f(1);
Sweet. Until you want to operate on xarray objects that hold a user-defined scalar type. In that case, std::exp will fail to compile because there is no overload of std::exp for this new type. Remember that adding overloads to the std namespace is undefined behavior and should be avoided.
Fortunately, C++ provides a mechanism to solve this apparent issue: Argument-Dependent Lookup (ADL). To summarize, for unqualified calls (i.e. calls to functions that are not prefixed with a namespace), the compiler can also find the definition of a function in the namespaces of its arguments:
namespace math
{
    struct special_double { ... };
    special_double exp(const special_double& sd) { ... }
}

namespace xt
{
    struct exp_fun
    {
        template <class T>
        constexpr auto operator()(const T& t) const
        {
            using std::exp;  // (1)
            return exp(t);   // (2)
        }
    };
}
Forget line (1) for a moment and only consider (2). When the functor is called with a special_double object, the compiler finds the definition of exp in the namespace math where special_double is defined, thanks to ADL.
But what if the functor is called with a simple double value? Since the built-in type double does not belong to any namespace, ADL finds no definition of exp. The using declaration in (1) makes std::exp available in the enclosing scope, providing a fallback for built-in types.
We could have stopped there; unfortunately, compilers are not always fully compliant with the standard. Some functions defined in the std namespace may have different return types, and some might be missing. To provide a uniform API, we need to work around these issues.
The idea is to provide the standard functions in a dedicated namespace math:
#define UNARY_MATH_FUNCTOR(NAME)                      \
    struct NAME##_fun                                 \
    {                                                 \
        template <class T>                            \
        constexpr auto operator()(const T& t) const   \
        {                                             \
            using math::NAME;                         \
            return NAME(t);                           \
        }                                             \
    }
If the implementation provided by the compiler conforms to the standard, we use it. Otherwise, we implement our own, as illustrated with isnan below:
namespace xt
{
    namespace math
    {
        // All these functions are standard-compliant on all
        // platforms
        using std::cos;
        using std::sin;
        // ...

        // isnan might return int instead of bool in glibc
        inline bool isnan(double d) { return bool(std::isnan(d)); }
        // ... overloads for float and integral types
    }
}
You can find the exhaustive list of functors and functions overloads in xoperation.hpp and xmath.hpp.
Conclusion
Operator and mathematical function overloading in xtensor is structured around three main components:
- The xfunction class, which stores the functor and the operands. The way each operand is stored is determined by its closure_type. Computation is performed upon element access.
- The functors describing the operations, generated by macros to avoid the cumbersome task of repeating boilerplate code. Their implementation relies on ADL and on the availability of uniform standard functions in a dedicated namespace.
- Generic operator and function overloads that accept universal references to expression types and return xfunction objects instantiated with the right functor type.
So far we have an expressive API to instantiate arbitrarily complex expression trees and access their elements. The next step is to make them assignable to xarray. This requires giving xfunction a more complete API, diving into the details of broadcasting, and defining the concept of semantics. The next article will focus on broadcasting and the xfunction API.
More about the Series
This post is just one episode of a long series of articles: