diff --git a/include/boost/compute/types/fundamental.hpp b/include/boost/compute/types/fundamental.hpp index c1502e32..a487b14a 100644 --- a/include/boost/compute/types/fundamental.hpp +++ b/include/boost/compute/types/fundamental.hpp @@ -44,33 +44,89 @@ typedef cl_double double_; #define BOOST_COMPUTE_MAKE_VECTOR_TYPE(scalar, size) \ BOOST_PP_CAT(BOOST_PP_CAT(::boost::compute::scalar, size), _) +namespace detail { + +// specialized vector_type base classes that provide the +// (x,y), (x,y,z,w), (s0..s7), (s0..sf) accessors +template class vector_type_desc; + +template +class vector_type_desc +{ +public: + Scalar x, y; + + Scalar& operator[](size_t i) + { + return (&x)[i]; + } + + const Scalar operator[](size_t i) const + { + return (&x)[i]; + } +}; + +template +class vector_type_desc : public vector_type_desc +{ +public: + Scalar z, w; +}; + +template +class vector_type_desc +{ +public: + Scalar s0, s1, s2, s3, s4, s5, s6, s7; + + Scalar& operator[](size_t i) + { + return (&s0)[i]; + } + + const Scalar operator[](size_t i) const + { + return (&s0)[i]; + } +}; + +template +class vector_type_desc : public vector_type_desc +{ +public: + Scalar s8, s9, sa, sb, sc, sd, se, sf; +}; + +} // end detail namespace + // vector data types template -class vector_type +class vector_type : public detail::vector_type_desc { public: typedef Scalar scalar_type; vector_type() { - + BOOST_STATIC_ASSERT(sizeof(Scalar) * N == sizeof(vector_type)); } explicit vector_type(const Scalar scalar) { for(size_t i = 0; i < N; i++) - m_value[i] = scalar; + (*this)[i] = scalar; } vector_type(const vector_type &other) { - std::memcpy(m_value, other.m_value, sizeof(m_value)); + std::memcpy(this, &other, sizeof(Scalar) * N); } vector_type& operator=(const vector_type &other) { - std::memcpy(m_value, other.m_value, sizeof(m_value)); + std::memcpy(this, &other, sizeof(Scalar) * N); return *this; } @@ -79,28 +135,15 @@ class vector_type return N; } - Scalar& operator[](size_t i) - { - return m_value[i]; - } - - Scalar operator[](size_t i) const - { - return m_value[i]; - } - bool operator==(const vector_type &other) const { - return std::memcmp(m_value, other.m_value, sizeof(m_value)) == 0; + return std::memcmp(this, &other, sizeof(Scalar) * N) == 0; } bool operator!=(const vector_type &other) const { return !(*this == other); } - -protected: - scalar_type m_value[N]; }; #define BOOST_COMPUTE_VECTOR_TYPE_CTOR_ARG_FUNCTION(z, i, _) \ @@ -108,9 +151,9 @@ class vector_type #define BOOST_COMPUTE_VECTOR_TYPE_DECLARE_CTOR_ARGS(scalar, size) \ BOOST_PP_REPEAT(size, BOOST_COMPUTE_VECTOR_TYPE_CTOR_ARG_FUNCTION, _) #define BOOST_COMPUTE_VECTOR_TYPE_ASSIGN_CTOR_ARG(z, i, _) \ - m_value[i] = BOOST_PP_CAT(arg, i); + (*this)[i] = BOOST_PP_CAT(arg, i); #define BOOST_COMPUTE_VECTOR_TYPE_ASSIGN_CTOR_SINGLE_ARG(z, i, _) \ - m_value[i] = arg; + (*this)[i] = arg; #define BOOST_COMPUTE_DECLARE_VECTOR_TYPE_CLASS(cl_scalar, size, class_name) \ class class_name : public vector_type \ diff --git a/test/test_types.cpp b/test/test_types.cpp index 3bc7cff9..3ee0e791 100644 --- a/test/test_types.cpp +++ b/test/test_types.cpp @@ -42,3 +42,56 @@ BOOST_AUTO_TEST_CASE(vector_string) stream << boost::compute::int2_(1, 2); BOOST_CHECK_EQUAL(stream.str(), std::string("int2(1, 2)")); } + +BOOST_AUTO_TEST_CASE(vector_accessors_basic) +{ + boost::compute::float4_ v; + v.x = 1; + v.y = 2; + v.z = 3; + v.w = 4; + BOOST_CHECK(v == boost::compute::float4_(1, 2, 3, 4)); +} + +BOOST_AUTO_TEST_CASE(vector_accessors_all) +{ + boost::compute::int2_ i2(1, 2); + BOOST_CHECK_EQUAL(i2.x, 1); + BOOST_CHECK_EQUAL(i2.y, 2); + + boost::compute::int4_ i4(1, 2, 3, 4); + BOOST_CHECK_EQUAL(i4.x, 1); + BOOST_CHECK_EQUAL(i4.y, 2); + BOOST_CHECK_EQUAL(i4.z, 3); + BOOST_CHECK_EQUAL(i4.w, 4); + + boost::compute::int8_ i8(1, 2, 3, 4, 5, 6, 7, 8); + BOOST_CHECK_EQUAL(i8.s0, 1); + BOOST_CHECK_EQUAL(i8.s1, 2); + BOOST_CHECK_EQUAL(i8.s2, 3); + BOOST_CHECK_EQUAL(i8.s3, 4); + BOOST_CHECK_EQUAL(i8.s4, 5); + BOOST_CHECK_EQUAL(i8.s5, 6); + BOOST_CHECK_EQUAL(i8.s6, 7); + BOOST_CHECK_EQUAL(i8.s7, 8); + + boost::compute::int16_ i16( + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16); + BOOST_CHECK_EQUAL(i16.s0, 1); + BOOST_CHECK_EQUAL(i16.s1, 2); + BOOST_CHECK_EQUAL(i16.s2, 3); + BOOST_CHECK_EQUAL(i16.s3, 4); + BOOST_CHECK_EQUAL(i16.s4, 5); + BOOST_CHECK_EQUAL(i16.s5, 6); + BOOST_CHECK_EQUAL(i16.s6, 7); + BOOST_CHECK_EQUAL(i16.s7, 8); + BOOST_CHECK_EQUAL(i16.s8, 9); + BOOST_CHECK_EQUAL(i16.s9, 10); + BOOST_CHECK_EQUAL(i16.sa, 11); + BOOST_CHECK_EQUAL(i16.sb, 12); + BOOST_CHECK_EQUAL(i16.sc, 13); + BOOST_CHECK_EQUAL(i16.sd, 14); + BOOST_CHECK_EQUAL(i16.se, 15); + BOOST_CHECK_EQUAL(i16.sf, 16); +}