How we wrote xtensor 6/N: Operator Overloading

In the previous article, we introduced expression templates and developed the skeleton of the xfunction class, meant to represent a node of the abstract syntax tree of a tensor expression. We left aside how xfunction is passed the types of its operands, as well as the way xfunction objects are instantiated. This article gets back to these key points of xtensor's expression system.

Closure types

Since xfunction is meant to represent a node of an abstract syntax tree, eventually it will be instantiated in the overload of an arithmetic operator or a mathematical function:

template <class E1, class E2>
inline auto operator+(const E1& op1,
                      const E2& op2)
{
    using function_type = xfunction<std::plus<>,
                                    const E1&,
                                    const E2&>;
    return function_type(std::plus<>(), op1, op2);
}

where E1 and E2 can be arbitrarily complex expressions (for now, xfunction parametrized by other types, or xarray). Therefore, the template arguments of xfunction are the parameter types of the operator that instantiates it. Simple, isn't it?

Unfortunately, things are a bit more complicated. Assume you want to pass a temporary returned from a function to the operator:

xarray<double> compute_array();
// ....
auto f = compute_array() + compute_array();

The temporaries returned from compute_array are bound to const references (the parameters of operator+), which is fine since these are stack-based references. Then they are "transferred" to the xfunction object and stored as data member references, which are NOT stack-based references. The rule that extends the lifetime of a temporary bound to a constant reference therefore does not apply, resulting in dangling references.

We need to detect at the call site whether the argument is an lvalue or a temporary. Fortunately, C++11 introduced universal references, which allow us to capture the value category of a function argument and achieve perfect forwarding. Diving into the details of universal references and perfect forwarding is beyond the scope of this article; you can find more about them on the isocpp blog.

However, we cannot use universal references to store the arguments in the xfunction object:

template <class F, class... CT>
class xfunction
{
    //...
    std::tuple<CT&&...> m_e;
};

Indeed, the key to universal references is the deducing context. When the tuple type is instantiated, the compiler already knows the types CT; no deduction takes place, so CT&& denotes a classical rvalue reference.

Therefore, we need an intermediate structure that specifies how to store the operands in the xfunction object: if the argument is an lvalue reference, it should be stored as a constant reference; otherwise, the argument is either an rvalue reference or a value, and it should be stored as a value. This (the way an argument is stored in the xfunction object) is called the closure type. Two versions should be provided: one that keeps the constness of the argument, and one that always "constifies" the reference to avoid side effects. Thanks to the type traits introduced in C++11 (and their alias templates from C++14), the implementation is straightforward:

template <class S>
struct closure_type
{
    using underlying_type = std::conditional_t<
        std::is_const<std::remove_reference_t<S>>::value,
        std::add_const_t<std::decay_t<S>>,
        std::decay_t<S>>;
    using type = std::conditional_t<
        std::is_lvalue_reference<S>::value,
        underlying_type&,
        underlying_type>;
};

template <class S>
using closure_type_t = typename closure_type<S>::type;

template <class S>
struct const_closure_type
{
    using underlying_type = std::decay_t<S>;
    using type = std::conditional_t<
        std::is_lvalue_reference<S>::value,
        std::add_const_t<underlying_type>&,
        underlying_type>;
};

template <class S>
using const_closure_type_t = typename const_closure_type<S>::type;

We now have all that we need to implement the addition operator:

template <class E1, class E2>
inline auto operator+(E1&& e1, E2&& e2)
{
    using function_type = xfunction<std::plus<>,
                                    const_closure_type_t<E1>,
                                    const_closure_type_t<E2>>;
    return function_type(std::plus<>(),
                         std::forward<E1>(e1),
                         std::forward<E2>(e2));
}

Operator overloads

Repeating the type of xfunction for each operator overload (and later for mathematical functions) is going to be cumbersome. Since the only difference between the operators is the type of the functor, we can refactor the instantiation of the xfunction object to avoid code duplication:

template <class F, class... CT>
struct xfunction_type
{
    using type = xfunction<F, const_closure_type_t<CT>...>;
};

template <class F, class... CT>
inline auto make_xfunction(CT&&... e) noexcept
{
    using function_type = typename xfunction_type<F, CT...>::type;
    return function_type(F(), std::forward<CT>(e)...);
}

template <class E1, class E2>
inline auto operator+(E1&& e1, E2&& e2)
{
    return make_xfunction<std::plus<>>(std::forward<E1>(e1),
                                       std::forward<E2>(e2));
}

That's much better! The implementation of the operator overload now simply forwards its arguments to the make_xfunction generator. Besides, if we later decide to change how the xfunction object is instantiated, we only need to modify make_xfunction instead of going through all the operator overloads and making the required changes.

Implementing the remaining arithmetic operators is just as simple; only the functor type differs (std::minus, std::multiplies, and std::divides). However, the STL does not provide functors for mathematical functions (std::exp, std::log, …), so we have to implement our own. Besides, we will use the functors to implement vectorized operations based on xsimd in the near future, so it is worth implementing our own arithmetic functors too.

Arithmetic functors

Let’s start with a simple addition functor, really close to std::plus from the STL:

namespace xt
{
    struct plus
    {
        template <class T1, class T2>
        constexpr auto operator()(const T1& t1, const T2& t2) const
        {
            return t1 + t2;
        }
    };
}

The main difference with std::plus is that our functor is not a class template; only its call operator is. Notice that this is the direction taken by the C++ standard too: the template parameter of std::plus is defaulted to void in C++14, and the inner types are deprecated in C++17 and removed in C++20.

The reason for this is mixed arithmetic. With the historical implementation in the STL, both operands must have the same type. This means that type promotion occurs before the call when possible; otherwise the type substitution fails, resulting in a compilation error.

However, there are cases where the type substitution fails but the operation is actually valid. Consider the following numerical class:

class special_double
{
public:

    explicit special_double(const double& d);
    special_double& operator+=(const double& d);
};

special_double operator+(const special_double& lhs,
                         const special_double& rhs);
special_double operator+(const special_double& lhs,
                         const double& rhs);
special_double operator+(const double& lhs,
                         const special_double& rhs);

Writing d + sd where d has type double and sd has type special_double is legit since we provide an overload of operator+ that accepts these types. However, special_double cannot be implicitly converted from double since its constructor is explicit. Therefore, any call to std::plus<special_double>::operator() where one of the arguments has type double would fail to compile.

Implementing functors for all operators is a repetitive and cumbersome task. Most of this code is boilerplate, so we should try to get it generated by macros:

#define UNARY_OPERATOR_FUNCTOR(NAME, OP)            \
    struct NAME                                     \
    {                                               \
        template <class T>                          \
        constexpr auto operator()(const T& t) const \
        {                                           \
            return OP t;                            \
        }                                           \
    }

#define BINARY_OPERATOR_FUNCTOR(NAME, OP)           \
    struct NAME                                     \
    {                                               \
        template <class T1, class T2>               \
        constexpr auto operator()(const T1& t1,     \
                                  const T2& t2) const \
        {                                           \
            return t1 OP t2;                        \
        }                                           \
    }

We can now easily generate all the required functors in the xt namespace:

namespace xt
{
    UNARY_OPERATOR_FUNCTOR(identity, +);
    UNARY_OPERATOR_FUNCTOR(negate, -);
    BINARY_OPERATOR_FUNCTOR(plus, +);
    BINARY_OPERATOR_FUNCTOR(minus, -);
    BINARY_OPERATOR_FUNCTOR(multiplies, *);
    BINARY_OPERATOR_FUNCTOR(divides, /);
    BINARY_OPERATOR_FUNCTOR(modulus, %);
    BINARY_OPERATOR_FUNCTOR(logical_or, ||);
    BINARY_OPERATOR_FUNCTOR(logical_and, &&);
    UNARY_OPERATOR_FUNCTOR(logical_not, !);
    BINARY_OPERATOR_FUNCTOR(bitwise_or, |);
    BINARY_OPERATOR_FUNCTOR(bitwise_and, &);
    BINARY_OPERATOR_FUNCTOR(bitwise_xor, ^);
    UNARY_OPERATOR_FUNCTOR(bitwise_not, ~);
    BINARY_OPERATOR_FUNCTOR(left_shift, <<);
    BINARY_OPERATOR_FUNCTOR(right_shift, >>);
    BINARY_OPERATOR_FUNCTOR(less, <);
    BINARY_OPERATOR_FUNCTOR(less_equal, <=);
    BINARY_OPERATOR_FUNCTOR(greater, >);
    BINARY_OPERATOR_FUNCTOR(greater_equal, >=);
    BINARY_OPERATOR_FUNCTOR(equal_to, ==);
    BINARY_OPERATOR_FUNCTOR(not_equal_to, !=);
}

Mathematical functors

Things are a bit more complicated for mathematical functions. The first idea would be to implement functors similar to the previous ones, but that forward their calls to the mathematical functions of the STL:

#define UNARY_MATH_FUNCTOR(NAME)                    \
    struct NAME##_fun                               \
    {                                               \
        template <class T>                          \
        constexpr auto operator()(const T& t) const \
        {                                           \
            return std::NAME(t);                    \
        }                                           \
    }

UNARY_MATH_FUNCTOR(exp);
...

We can then implement the xt::exp function that accepts N-dimensional expressions:

template <class E>
inline auto exp(E&& e)
{
    return make_xfunction<xt::exp_fun>(std::forward<E>(e));
}

Calling exp on an xarray object now returns an xfunction parametrized by exp_fun. When one tries to access an element, the functor is called with the corresponding value of the underlying array, and forwards the call to std::exp:

xt::xarray<double> a = { .... };
// f has type xfunction<exp_fun, const xarray<double>&>
auto f = exp(a);
// internally calls exp_fun::operator()(a(1))
double res = f(1);

Sweet. Until you want to operate on xarray objects that hold a user-defined scalar type. In that case, std::exp will fail to compile because there is no overload of std::exp for this new type. Remember that adding overloads to the std namespace is undefined behavior and should be avoided.

Fortunately, C++ provides a mechanism to solve this apparent issue: Argument-Dependent Lookup (ADL). To summarize, for unqualified calls (i.e. calls to functions that are not prefixed with a namespace), the compiler also looks up the function in the namespaces of its arguments:

namespace math
{
    struct special_double { ... };

    special_double exp(const special_double& sd) { ... }
}

namespace xt
{
    struct exp_fun
    {
        template <class T>
        constexpr auto operator()(const T& t) const
        {
            using std::exp;  // (1)
            return exp(t);   // (2)
        }
    };
}

Forget line (1) for a moment and only consider (2). When the functor is called with a special_double object, the compiler finds the definition of exp in the namespace math where special_double is defined, thanks to ADL.

But what if the functor is called with a simple double value? Since the type double is not defined in std, no definition of exp is available through ADL. The using declaration in (1) makes std::exp visible in the current scope, providing a fallback for built-in types.

We could have stopped there; unfortunately, compilers are not always fully compliant with the standard. Some functions defined in the std namespace may have different return types, and some might be missing. To provide a uniform API, we need to work around these issues.

The idea is to provide standard functions in a dedicated namespace math:

#define UNARY_MATH_FUNCTOR(NAME)                    \
    struct NAME##_fun                               \
    {                                               \
        template <class T>                          \
        constexpr auto operator()(const T& t) const \
        {                                           \
            using math::NAME;                       \
            return NAME(t);                         \
        }                                           \
    }

If the implementation provided by the compiler conforms to the standard, we use it. Otherwise, we implement our own, as illustrated with isnan below:

namespace xt
{
    namespace math
    {
        // All these functions are standard-compliant on all
        // platforms
        using std::cos;
        using std::sin;
        // ...

        // isnan might return int instead of bool in glibc
        inline bool isnan(double d) { return bool(std::isnan(d)); }
        // ... overloads for float and integral types
    }
}

You can find the exhaustive list of functors and functions overloads in xoperation.hpp and xmath.hpp.

Conclusion

Operators and mathematical functions overloading in xtensor are structured around three main components:

  • The xfunction class which stores the functor and the operands. The way each operand is stored is determined by its closure_type. Computation is performed upon element access.
  • The functors describing the operations, generated by macros to avoid the cumbersome task of repeating boilerplate code. Their implementation relies on ADL and the availability of uniform standard functions in a dedicated namespace.
  • Generic operators and functions overloads that accept universal references on expression types and return xfunction objects instantiated with the right functor type.

So far we have an expressive API to instantiate arbitrarily complex expression trees and access their elements. The next step is to make them assignable to xarray. This requires giving xfunction a more complete API, diving into the details of broadcasting, and defining the concept of semantics. The next article will focus on broadcasting and the xfunction API.

More about the Series

This post is just one episode of a long series of articles:

How We Wrote Xtensor
