TensorExpr.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12 
13 namespace Eigen {
14 
30 namespace internal {
31 template<typename NullaryOp, typename XprType>
32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
33  : traits<XprType>
34 {
35  typedef typename XprType::Packet Packet;
36  typedef traits<XprType> XprTraits;
37  typedef typename XprType::Scalar Scalar;
38  typedef typename XprType::Nested XprTypeNested;
39  typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
40  static const int NumDimensions = XprTraits::NumDimensions;
41  static const int Layout = XprTraits::Layout;
42 
43  enum {
44  Flags = 0,
45  };
46 };
47 
48 } // end namespace internal
49 
50 
51 
52 template<typename NullaryOp, typename XprType>
53 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
54 {
55  public:
56  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
57  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet;
58  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
59  typedef typename XprType::CoeffReturnType CoeffReturnType;
60  typedef typename XprType::PacketReturnType PacketReturnType;
61  typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
62  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
63  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
64 
65  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
66  : m_xpr(xpr), m_functor(func) {}
67 
68  EIGEN_DEVICE_FUNC
69  const typename internal::remove_all<typename XprType::Nested>::type&
70  nestedExpression() const { return m_xpr; }
71 
72  EIGEN_DEVICE_FUNC
73  const NullaryOp& functor() const { return m_functor; }
74 
75  protected:
76  typename XprType::Nested m_xpr;
77  const NullaryOp m_functor;
78 };
79 
80 
81 
82 namespace internal {
83 template<typename UnaryOp, typename XprType>
84 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
85  : traits<XprType>
86 {
87  // TODO(phli): Add InputScalar, InputPacket. Check references to
88  // current Scalar/Packet to see if the intent is Input or Output.
89  typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
90  typedef traits<XprType> XprTraits;
91  typedef typename internal::packet_traits<Scalar>::type Packet;
92  typedef typename XprType::Nested XprTypeNested;
93  typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
94  static const int NumDimensions = XprTraits::NumDimensions;
95  static const int Layout = XprTraits::Layout;
96 };
97 
98 template<typename UnaryOp, typename XprType>
99 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
100 {
101  typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
102 };
103 
104 template<typename UnaryOp, typename XprType>
105 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
106 {
107  typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
108 };
109 
110 } // end namespace internal
111 
112 
113 
114 template<typename UnaryOp, typename XprType>
115 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
116 {
117  public:
118  // TODO(phli): Add InputScalar, InputPacket. Check references to
119  // current Scalar/Packet to see if the intent is Input or Output.
120  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
121  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Packet Packet;
122  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
123  typedef Scalar CoeffReturnType;
124  typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
125  typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
126  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
127  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
128 
129  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
130  : m_xpr(xpr), m_functor(func) {}
131 
132  EIGEN_DEVICE_FUNC
133  const UnaryOp& functor() const { return m_functor; }
134 
136  EIGEN_DEVICE_FUNC
137  const typename internal::remove_all<typename XprType::Nested>::type&
138  nestedExpression() const { return m_xpr; }
139 
140  protected:
141  typename XprType::Nested m_xpr;
142  const UnaryOp m_functor;
143 };
144 
145 
146 namespace internal {
147 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
148 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
149 {
150  // Type promotion to handle the case where the types of the lhs and the rhs
151  // are different.
152  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
153  // current Scalar/Packet to see if the intent is Inputs or Output.
154  typedef typename result_of<
155  BinaryOp(typename LhsXprType::Scalar,
156  typename RhsXprType::Scalar)>::type Scalar;
157  typedef traits<LhsXprType> XprTraits;
158  typedef typename internal::packet_traits<Scalar>::type Packet;
159  typedef typename promote_storage_type<
160  typename traits<LhsXprType>::StorageKind,
161  typename traits<RhsXprType>::StorageKind>::ret StorageKind;
162  typedef typename promote_index_type<
163  typename traits<LhsXprType>::Index,
164  typename traits<RhsXprType>::Index>::type Index;
165  typedef typename LhsXprType::Nested LhsNested;
166  typedef typename RhsXprType::Nested RhsNested;
167  typedef typename remove_reference<LhsNested>::type _LhsNested;
168  typedef typename remove_reference<RhsNested>::type _RhsNested;
169  static const int NumDimensions = XprTraits::NumDimensions;
170  static const int Layout = XprTraits::Layout;
171 
172  enum {
173  Flags = 0,
174  };
175 };
176 
177 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
178 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
179 {
180  typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
181 };
182 
183 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
184 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
185 {
186  typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
187 };
188 
189 } // end namespace internal
190 
191 
192 
193 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
194 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
195 {
196  public:
197  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
198  // current Scalar/Packet to see if the intent is Inputs or Output.
199  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
200  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet;
201  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
202  typedef Scalar CoeffReturnType;
203  typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
204  typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
205  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
206  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
207 
208  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
209  : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
210 
211  EIGEN_DEVICE_FUNC
212  const BinaryOp& functor() const { return m_functor; }
213 
215  EIGEN_DEVICE_FUNC
216  const typename internal::remove_all<typename LhsXprType::Nested>::type&
217  lhsExpression() const { return m_lhs_xpr; }
218 
219  EIGEN_DEVICE_FUNC
220  const typename internal::remove_all<typename RhsXprType::Nested>::type&
221  rhsExpression() const { return m_rhs_xpr; }
222 
223  protected:
224  typename LhsXprType::Nested m_lhs_xpr;
225  typename RhsXprType::Nested m_rhs_xpr;
226  const BinaryOp m_functor;
227 };
228 
229 
230 namespace internal {
231 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
232 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
233  : traits<ThenXprType>
234 {
235  typedef typename traits<ThenXprType>::Scalar Scalar;
236  typedef traits<ThenXprType> XprTraits;
237  typedef typename packet_traits<Scalar>::type Packet;
238  typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
239  typename traits<ElseXprType>::StorageKind>::ret StorageKind;
240  typedef typename promote_index_type<typename traits<ElseXprType>::Index,
241  typename traits<ThenXprType>::Index>::type Index;
242  typedef typename IfXprType::Nested IfNested;
243  typedef typename ThenXprType::Nested ThenNested;
244  typedef typename ElseXprType::Nested ElseNested;
245  static const int NumDimensions = XprTraits::NumDimensions;
246  static const int Layout = XprTraits::Layout;
247 };
248 
249 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
250 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
251 {
252  typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
253 };
254 
255 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
256 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
257 {
258  typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
259 };
260 
261 } // end namespace internal
262 
263 
264 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
265 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
266 {
267  public:
268  typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
269  typedef typename Eigen::internal::traits<TensorSelectOp>::Packet Packet;
270  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
271  typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
272  typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
273  typedef typename internal::promote_storage_type<typename ThenXprType::PacketReturnType,
274  typename ElseXprType::PacketReturnType>::ret PacketReturnType;
275  typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
276  typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
277  typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
278 
279  EIGEN_DEVICE_FUNC
280  TensorSelectOp(const IfXprType& a_condition,
281  const ThenXprType& a_then,
282  const ElseXprType& a_else)
283  : m_condition(a_condition), m_then(a_then), m_else(a_else)
284  { }
285 
286  EIGEN_DEVICE_FUNC
287  const IfXprType& ifExpression() const { return m_condition; }
288 
289  EIGEN_DEVICE_FUNC
290  const ThenXprType& thenExpression() const { return m_then; }
291 
292  EIGEN_DEVICE_FUNC
293  const ElseXprType& elseExpression() const { return m_else; }
294 
295  protected:
296  typename IfXprType::Nested m_condition;
297  typename ThenXprType::Nested m_then;
298  typename ElseXprType::Nested m_else;
299 };
300 
301 
302 } // end namespace Eigen
303 
304 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13