TensorCustomOp.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12 
13 namespace Eigen {
14 
22 namespace internal {
23 template<typename CustomUnaryFunc, typename XprType>
24 struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
25 {
26  typedef typename XprType::Scalar Scalar;
27  typedef typename packet_traits<Scalar>::type Packet;
28  typedef typename XprType::StorageKind StorageKind;
29  typedef typename XprType::Index Index;
30  typedef typename XprType::Nested Nested;
31  typedef typename remove_reference<Nested>::type _Nested;
32  static const int NumDimensions = traits<XprType>::NumDimensions;
33  static const int Layout = traits<XprType>::Layout;
34 };
35 
36 template<typename CustomUnaryFunc, typename XprType>
37 struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
38 {
39  typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
40 };
41 
42 template<typename CustomUnaryFunc, typename XprType>
43 struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
44 {
45  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
46 };
47 
48 } // end namespace internal
49 
50 
51 
52 template<typename CustomUnaryFunc, typename XprType>
53 class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
54 {
55  public:
56  typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
57  typedef typename internal::traits<TensorCustomUnaryOp>::Packet Packet;
58  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
59  typedef typename XprType::CoeffReturnType CoeffReturnType;
60  typedef typename XprType::PacketReturnType PacketReturnType;
61  typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
62  typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
63  typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
64 
65  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
66  : m_expr(expr), m_func(func) {}
67 
68  EIGEN_DEVICE_FUNC
69  const CustomUnaryFunc& func() const { return m_func; }
70 
71  EIGEN_DEVICE_FUNC
72  const typename internal::remove_all<typename XprType::Nested>::type&
73  expression() const { return m_expr; }
74 
75  protected:
76  typename XprType::Nested m_expr;
77  const CustomUnaryFunc m_func;
78 };
79 
80 
81 // Eval as rvalue
82 template<typename CustomUnaryFunc, typename XprType, typename Device>
83 struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
84 {
86  typedef typename internal::traits<ArgType>::Index Index;
87  static const int NumDims = internal::traits<ArgType>::NumDimensions;
88  typedef DSizes<Index, NumDims> Dimensions;
89  typedef
90  typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
91 
92  enum {
93  IsAligned = false,
94  PacketAccess = (internal::packet_traits<Scalar>::size > 1),
95  BlockAccess = false,
97  CoordAccess = false, // to be implemented
98  };
99 
100  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
101  : m_op(op), m_device(device), m_result(NULL)
102  {
103  m_dimensions = op.func().dimensions(op.expression());
104  }
105 
106  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
107  typedef typename XprType::PacketReturnType PacketReturnType;
108 
109  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
110 
111  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
112  if (data) {
113  evalTo(data);
114  return false;
115  } else {
116  m_result = static_cast<CoeffReturnType*>(
117  m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
118  evalTo(m_result);
119  return true;
120  }
121  }
122 
123  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
124  if (m_result != NULL) {
125  m_device.deallocate(m_result);
126  m_result = NULL;
127  }
128  }
129 
130  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
131  return m_result[index];
132  }
133 
134  template<int LoadMode>
135  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
136  return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
137  }
138 
139  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
140 
141  protected:
142  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
144  data, m_dimensions);
145  m_op.func().eval(m_op.expression(), result, m_device);
146  }
147 
148  Dimensions m_dimensions;
149  const ArgType m_op;
150  const Device& m_device;
151  CoeffReturnType* m_result;
152 };
153 
154 
155 
163 namespace internal {
164 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
165 struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
166 {
167  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
168  typename RhsXprType::Scalar>::ret Scalar;
169  typedef typename packet_traits<Scalar>::type Packet;
170  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
171  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
172  typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
173  typename RhsXprType::PacketReturnType>::ret PacketReturnType;
174  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
175  typename traits<RhsXprType>::StorageKind>::ret StorageKind;
176  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
177  typename traits<RhsXprType>::Index>::type Index;
178  typedef typename LhsXprType::Nested LhsNested;
179  typedef typename RhsXprType::Nested RhsNested;
180  typedef typename remove_reference<LhsNested>::type _LhsNested;
181  typedef typename remove_reference<RhsNested>::type _RhsNested;
182  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
183  static const int Layout = traits<LhsXprType>::Layout;
184 };
185 
186 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
187 struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
188 {
190 };
191 
192 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
193 struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
194 {
196 };
197 
198 } // end namespace internal
199 
200 
201 
202 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
203 class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
204 {
205  public:
206  typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
207  typedef typename internal::traits<TensorCustomBinaryOp>::Packet Packet;
208  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
209  typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
210  typedef typename internal::traits<TensorCustomBinaryOp>::PacketReturnType PacketReturnType;
211  typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
212  typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
213  typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
214 
215  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
216 
217  : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
218 
219  EIGEN_DEVICE_FUNC
220  const CustomBinaryFunc& func() const { return m_func; }
221 
222  EIGEN_DEVICE_FUNC
223  const typename internal::remove_all<typename LhsXprType::Nested>::type&
224  lhsExpression() const { return m_lhs_xpr; }
225 
226  EIGEN_DEVICE_FUNC
227  const typename internal::remove_all<typename RhsXprType::Nested>::type&
228  rhsExpression() const { return m_rhs_xpr; }
229 
230  protected:
231  typename LhsXprType::Nested m_lhs_xpr;
232  typename RhsXprType::Nested m_rhs_xpr;
233  const CustomBinaryFunc m_func;
234 };
235 
236 
237 // Eval as rvalue
238 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
239 struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
240 {
242  typedef typename internal::traits<XprType>::Index Index;
243  static const int NumDims = internal::traits<XprType>::NumDimensions;
244  typedef DSizes<Index, NumDims> Dimensions;
245  typedef typename XprType::Scalar Scalar;
246 
247  enum {
248  IsAligned = false,
249  PacketAccess = (internal::packet_traits<Scalar>::size > 1),
250  BlockAccess = false,
252  CoordAccess = false, // to be implemented
253  };
254 
255  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
256  : m_op(op), m_device(device), m_result(NULL)
257  {
258  m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
259  }
260 
261  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
262  typedef typename XprType::PacketReturnType PacketReturnType;
263 
264  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
265 
266  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
267  if (data) {
268  evalTo(data);
269  return false;
270  } else {
271  m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
272  evalTo(m_result);
273  return true;
274  }
275  }
276 
277  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
278  if (m_result != NULL) {
279  m_device.deallocate(m_result);
280  m_result = NULL;
281  }
282  }
283 
284  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
285  return m_result[index];
286  }
287 
288  template<int LoadMode>
289  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
290  return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
291  }
292 
293  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
294 
295  protected:
296  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
297  TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
298  m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
299  }
300 
301  Dimensions m_dimensions;
302  const XprType m_op;
303  const Device& m_device;
304  CoeffReturnType* m_result;
305 };
306 
307 
308 } // end namespace Eigen
309 
310 #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
Tensor custom unary operation class.
Definition: TensorCustomOp.h:53
Tensor custom binary operation class.
Definition: TensorCustomOp.h:203
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13
The tensor evaluator classes.
Definition: TensorEvaluator.h:28
A tensor expression mapping an existing array of data.
Definition: TensorForwardDeclarations.h:17
The tensor base class.
Definition: TensorForwardDeclarations.h:19