#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

/** \class TensorBroadcasting
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor broadcasting class.
  */
namespace internal {

template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

}  // end namespace internal
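
// Illustrative usage (a sketch, not part of the original header): the
// broadcast factors are per-dimension repetition counts, so broadcasting a
// 2x3 tensor by {2, 2} tiles it twice along each dimension into a 4x6 result:
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setRandom();
//   Eigen::array<int, 2> bcast{{2, 2}};
//   Eigen::Tensor<float, 2> r = t.broadcast(bcast);  // r is 4x6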

template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Packet Packet;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename XprType::PacketReturnType PacketReturnType;
    typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
        : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};
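
// Evaluator for the broadcasting expression (read-only / rvalue). It
// precomputes the output dimensions and the input and output strides so that
// every coefficient access reduces to a few divisions and modulos.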

template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;

  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    // Broadcasting does not change the rank of the tensor, and broadcasting a
    // rank-0 (scalar) tensor is not supported: reshape the scalar into a
    // one-element tensor of rank >= 1 first, then broadcast.
    EIGEN_STATIC_ASSERT(NumDims > 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    const Broadcast& broadcast = op.broadcast();
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * broadcast[i];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }
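
  // Worked example (illustration, not from the original source): a column-major
  // 2x3 input broadcast by factors (2, 2) yields a 4x6 output with input
  // strides (1, 2) and output strides (1, 4). Output index 13 decomposes as
  // column 13/4 = 3 and row 13%4 = 1, which wraps back to input element
  // (1 % 2, 3 % 3) = (1, 0).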

  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return coeffColMajor(index);
    } else {
      return coeffRowMajor(index);
    }
  }
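
  // Output-to-input index mapping: recover each output coordinate by div/mod
  // against the output strides, wrap it modulo the corresponding input
  // dimension, and accumulate input strides. The index_statically_eq tests let
  // the compiler elide the wrap when a broadcast factor or an input dimension
  // is statically known to be 1. The integer divisions and modulos dominate
  // the cost of this path.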

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims-1]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return packetColMajor<LoadMode>(index);
    } else {
      return packetRowMajor<LoadMode>(index);
    }
  }
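
  // Both packet paths ignore LoadMode and use unaligned loads, since
  // broadcasting generally breaks the alignment of the underlying data. When a
  // whole packet fits inside a single copy of the innermost input dimension it
  // is loaded directly; otherwise the coefficients are gathered one by one.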

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+packetSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // Load the packet directly whenever it fits within a single copy of the
    // innermost input dimension; gather coefficient by coefficient otherwise.
    if (innermostLoc + packetSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < packetSize; ++i) {
        values[i] = coeffColMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
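
  // The row-major variant mirrors packetColMajor, with the contiguous
  // innermost dimension at index NumDims-1 instead of 0.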

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+packetSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // Load the packet directly whenever it fits within a single copy of the
    // innermost input dimension; gather coefficient by coefficient otherwise.
    if (innermostLoc + packetSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < packetSize; ++i) {
        values[i] = coeffRowMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H