diff --git a/modules/core/src/matop.cpp b/modules/core/src/matop.cpp index eefd63318..23b6940c8 100644 --- a/modules/core/src/matop.cpp +++ b/modules/core/src/matop.cpp @@ -200,6 +200,7 @@ public: void multiply(const MatExpr& e, double s, MatExpr& res) const; static void makeExpr(MatExpr& res, int method, Size sz, int type, double alpha=1); + static void makeExpr(MatExpr& res, int method, int ndims, const int* sizes, int type, double alpha=1); }; static MatOp_Initializer* getGlobalMatOpInitializer() @@ -1555,8 +1556,13 @@ void MatOp_Initializer::assign(const MatExpr& e, Mat& m, int _type) const { if( _type == -1 ) _type = e.a.type(); - m.create(e.a.size(), _type); - if( e.flags == 'I' ) + + if( e.a.dims <= 2 ) + m.create(e.a.size(), _type); + else + m.create(e.a.dims, e.a.size, _type); + + if( e.flags == 'I' && e.a.dims <= 2 ) setIdentity(m, Scalar(e.alpha)); else if( e.flags == '0' ) m = Scalar(); @@ -1577,6 +1583,12 @@ inline void MatOp_Initializer::makeExpr(MatExpr& res, int method, Size sz, int t res = MatExpr(getGlobalMatOpInitializer(), method, Mat(sz, type, (void*)0), Mat(), Mat(), alpha, 0); } +inline void MatOp_Initializer::makeExpr(MatExpr& res, int method, int ndims, const int* sizes, int type, double alpha) +{ + res = MatExpr(getGlobalMatOpInitializer(), method, Mat(ndims, sizes, type, (void*)0), Mat(), Mat(), alpha, 0); +} + + /////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1636,6 +1648,20 @@ MatExpr Mat::ones(Size size, int type) return e; } +MatExpr Mat::zeros(int ndims, const int* sizes, int type) +{ + MatExpr e; + MatOp_Initializer::makeExpr(e, '0', ndims, sizes, type); + return e; +} + +MatExpr Mat::ones(int ndims, const int* sizes, int type) +{ + MatExpr e; + MatOp_Initializer::makeExpr(e, '1', ndims, sizes, type); + return e; +} + MatExpr Mat::eye(int rows, int cols, int type) { MatExpr e; diff --git a/modules/core/test/test_mat.cpp b/modules/core/test/test_mat.cpp index a6ebe152d..f854abed7 100644 --- a/modules/core/test/test_mat.cpp +++ b/modules/core/test/test_mat.cpp @@ -918,3 +918,18 @@ TEST(Core_Mat, copyNx1ToVector) ASSERT_PRED_FORMAT2(cvtest::MatComparator(0, 0), ref_dst16, cv::Mat_(dst16)); } + +TEST(Core_Mat, multiDim) +{ + int d[]={3,3,3}; + Mat m0 = Mat::zeros(3,d,CV_8U); + ASSERT_EQ(0,sum(m0)[0]); + Mat m = Mat::ones(3,d,CV_8U); + ASSERT_EQ(27,sum(m)[0]); + m += 2; + ASSERT_EQ(81,sum(m)[0]); + m *= 3; + ASSERT_EQ(243,sum(m)[0]); + m += m; + ASSERT_EQ(486,sum(m)[0]); +}