From c3d13d38f481e5217aa1209f55127ac1061bab39 Mon Sep 17 00:00:00 2001 From: "jan.skoglund@webrtc.org" Date: Mon, 10 Mar 2014 22:50:19 +0000 Subject: [PATCH] Classes and tests for an audio classifier. The class can be used to classify whether a frame of audio contains speech or music. The classifier uses the music/speech classifier in Opus. R=andrew@webrtc.org, henrik.lundin@webrtc.org, turaj@webrtc.org Review URL: https://webrtc-codereview.appspot.com/5549004 git-svn-id: http://webrtc.googlecode.com/svn/trunk@5677 4adac7df-926f-26a2-2b94-8c16560cd09d --- .../audio_coding/neteq4/audio_classifier.cc | 71 ++++++++++++ .../audio_coding/neteq4/audio_classifier.h | 59 ++++++++++ .../neteq4/audio_classifier_unittest.cc | 75 +++++++++++++ webrtc/modules/audio_coding/neteq4/neteq.gypi | 14 ++- .../audio_coding/neteq4/neteq_tests.gypi | 11 ++ .../neteq4/test/audio_classifier_test.cc | 105 ++++++++++++++++++ webrtc/modules/modules.gyp | 1 + webrtc/modules/modules_unittests.isolate | 10 ++ 8 files changed, 342 insertions(+), 4 deletions(-) create mode 100644 webrtc/modules/audio_coding/neteq4/audio_classifier.cc create mode 100644 webrtc/modules/audio_coding/neteq4/audio_classifier.h create mode 100644 webrtc/modules/audio_coding/neteq4/audio_classifier_unittest.cc create mode 100644 webrtc/modules/audio_coding/neteq4/test/audio_classifier_test.cc diff --git a/webrtc/modules/audio_coding/neteq4/audio_classifier.cc b/webrtc/modules/audio_coding/neteq4/audio_classifier.cc new file mode 100644 index 000000000..a272fbce3 --- /dev/null +++ b/webrtc/modules/audio_coding/neteq4/audio_classifier.cc @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_coding/neteq4/audio_classifier.h" + +#include +#include + +namespace webrtc { + +static const int kDefaultSampleRateHz = 48000; +static const int kDefaultFrameRateHz = 50; +static const int kDefaultFrameSizeSamples = + kDefaultSampleRateHz / kDefaultFrameRateHz; +static const float kDefaultThreshold = 0.5f; + +AudioClassifier::AudioClassifier() + : analysis_info_(), + is_music_(false), + music_probability_(0), + // This actually assigns the pointer to a static constant struct + // rather than creates a struct and |celt_mode_| does not need + // to be deleted. + celt_mode_(opus_custom_mode_create(kDefaultSampleRateHz, + kDefaultFrameSizeSamples, + NULL)), + analysis_state_() { + assert(celt_mode_); +} + +AudioClassifier::~AudioClassifier() {} + +bool AudioClassifier::Analysis(const int16_t* input, + int input_length, + int channels) { + // Only mono or stereo are allowed; checked first, before |channels| is used. + assert(channels == 1 || channels == 2); + + // Must be exactly one 20 ms frame per channel at 48 kHz sampling. + assert(input_length == channels * kDefaultFrameSizeSamples); + + // Call Opus' classifier, defined in + // "third_party/opus/src/src/analysis.h", with lsb_depth = 16. + // Also uses a down-mixing function downmix_int, defined in + // "third_party/opus/src/src/opus_private.h", with + // constants c1 = 0, and c2 = -2. 
 + run_analysis(&analysis_state_, + celt_mode_, + input, + kDefaultFrameSizeSamples, + kDefaultFrameSizeSamples, + 0, + -2, + channels, + kDefaultSampleRateHz, + 16, + downmix_int, + &analysis_info_); + music_probability_ = analysis_info_.music_prob; + is_music_ = music_probability_ > kDefaultThreshold; + return is_music_; +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_coding/neteq4/audio_classifier.h b/webrtc/modules/audio_coding/neteq4/audio_classifier.h new file mode 100644 index 000000000..7451d3e3c --- /dev/null +++ b/webrtc/modules/audio_coding/neteq4/audio_classifier.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_CODING_NETEQ4_AUDIO_CLASSIFIER_H_ +#define WEBRTC_MODULES_AUDIO_CODING_NETEQ4_AUDIO_CLASSIFIER_H_ + +#if defined(__cplusplus) +extern "C" { +#endif +#include "third_party/opus/src/celt/celt.h" +#include "third_party/opus/src/src/analysis.h" +#include "third_party/opus/src/src/opus_private.h" +#if defined(__cplusplus) +} +#endif + +#include "webrtc/system_wrappers/interface/scoped_ptr.h" +#include "webrtc/typedefs.h" + +namespace webrtc { + +// This class provides a speech/music classification and is a wrapper over the +// Opus classifier. It currently only supports 48 kHz mono or stereo with a +// frame size of 20 ms. + +class AudioClassifier { + public: + AudioClassifier(); + virtual ~AudioClassifier(); + + // Classifies one frame of audio data in input, + // input_length : must be channels * 960; + // channels : must be 1 (mono) or 2 (stereo). 
+ bool Analysis(const int16_t* input, int input_length, int channels); + + // Gets the current classification : true = music, false = speech. + bool is_music() const { return is_music_; } + + // Gets the current music probability. + float music_probability() const { return music_probability_; } + + private: + AnalysisInfo analysis_info_; + bool is_music_; + float music_probability_; + const CELTMode* celt_mode_; + TonalityAnalysisState analysis_state_; +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_CODING_NETEQ4_AUDIO_CLASSIFIER_H_ diff --git a/webrtc/modules/audio_coding/neteq4/audio_classifier_unittest.cc b/webrtc/modules/audio_coding/neteq4/audio_classifier_unittest.cc new file mode 100644 index 000000000..0a6671816 --- /dev/null +++ b/webrtc/modules/audio_coding/neteq4/audio_classifier_unittest.cc @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_coding/neteq4/audio_classifier.h" + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "webrtc/test/testsupport/fileutils.h" + +namespace webrtc { + +static const size_t kFrameSize = 960; + +TEST(AudioClassifierTest, AllZeroInput) { + int16_t in_mono[kFrameSize] = {0}; + + // Test all-zero vectors and let the classifier converge from its default + // to the expected value. 
+ AudioClassifier zero_classifier; + for (int i = 0; i < 100; ++i) { + zero_classifier.Analysis(in_mono, kFrameSize, 1); + } + EXPECT_TRUE(zero_classifier.is_music()); +} + +void RunAnalysisTest(const std::string& audio_filename, + const std::string& data_filename, + size_t channels) { + AudioClassifier classifier; + scoped_ptr in(new int16_t[channels * kFrameSize]); + bool is_music_ref; + + FILE* audio_file = fopen(audio_filename.c_str(), "rb"); + ASSERT_TRUE(audio_file != NULL) << "Failed to open file " << audio_filename + << std::endl; + FILE* data_file = fopen(data_filename.c_str(), "rb"); + ASSERT_TRUE(audio_file != NULL) << "Failed to open file " << audio_filename + << std::endl; + while (fread(in.get(), sizeof(int16_t), channels * kFrameSize, audio_file) == + channels * kFrameSize) { + bool is_music = + classifier.Analysis(in.get(), channels * kFrameSize, channels); + EXPECT_EQ(is_music, classifier.is_music()); + ASSERT_EQ(1u, fread(&is_music_ref, sizeof(is_music_ref), 1, data_file)); + EXPECT_EQ(is_music_ref, is_music); + } + fclose(audio_file); + fclose(data_file); +} + +TEST(AudioClassifierTest, DoAnalysisMono) { + RunAnalysisTest(test::ResourcePath("short_mixed_mono_48", "pcm"), + test::ResourcePath("short_mixed_mono_48", "dat"), + 1); +} + +TEST(AudioClassifierTest, DoAnalysisStereo) { + RunAnalysisTest(test::ResourcePath("short_mixed_stereo_48", "pcm"), + test::ResourcePath("short_mixed_stereo_48", "dat"), + 2); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_coding/neteq4/neteq.gypi b/webrtc/modules/audio_coding/neteq4/neteq.gypi index 46601099d..afcefbe97 100644 --- a/webrtc/modules/audio_coding/neteq4/neteq.gypi +++ b/webrtc/modules/audio_coding/neteq4/neteq.gypi @@ -16,6 +16,7 @@ 'iSAC', 'iSACFix', 'CNG', + '<(DEPTH)/third_party/opus/opus.gyp:opus', '<(webrtc_root)/common_audio/common_audio.gyp:common_audio', '<(webrtc_root)/system_wrappers/source/system_wrappers.gyp:system_wrappers', ], @@ -38,20 +39,25 @@ '<@(neteq_defines)', ], 
'include_dirs': [ - 'interface', - '<(webrtc_root)', + # Need Opus header files for the audio classifier. + '<(DEPTH)/third_party/opus/src/celt', ], 'direct_dependent_settings': { 'include_dirs': [ - 'interface', - '<(webrtc_root)', + # Need Opus header files for the audio classifier. + '<(DEPTH)/third_party/opus/src/celt', ], }, + 'export_dependent_settings': [ + '<(DEPTH)/third_party/opus/opus.gyp:opus', + ], 'sources': [ 'interface/audio_decoder.h', 'interface/neteq.h', 'accelerate.cc', 'accelerate.h', + 'audio_classifier.cc', + 'audio_classifier.h', 'audio_decoder_impl.cc', 'audio_decoder_impl.h', 'audio_decoder.cc', diff --git a/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi b/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi index 419aefa1c..e1fcae7aa 100644 --- a/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi +++ b/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi @@ -140,6 +140,17 @@ ], }, + { + 'target_name': 'audio_classifier_test', + 'type': 'executable', + 'dependencies': [ + 'NetEq4', + ], + 'sources': [ + 'test/audio_classifier_test.cc', + ], + }, + { 'target_name': 'neteq4_speed_test', 'type': 'executable', diff --git a/webrtc/modules/audio_coding/neteq4/test/audio_classifier_test.cc b/webrtc/modules/audio_coding/neteq4/test/audio_classifier_test.cc new file mode 100644 index 000000000..730406bb9 --- /dev/null +++ b/webrtc/modules/audio_coding/neteq4/test/audio_classifier_test.cc @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "webrtc/modules/audio_coding/neteq4/audio_classifier.h" + +#include +#include +#include +#include + +#include +#include + +#include "webrtc/system_wrappers/interface/scoped_ptr.h" + +int main(int argc, char* argv[]) { + if (argc != 5) { + std::cout << "Usage: " << argv[0] << + " channels output_type " + << std::endl << std::endl; + std::cout << "Where channels can be 1 (mono) or 2 (interleaved stereo),"; + std::cout << " outputs can be 1 (classification (boolean)) or 2"; + std::cout << " (classification and music probability (float))," + << std::endl; + std::cout << "and the sampling frequency is assumed to be 48 kHz." + << std::endl; + return -1; + } + + const int kFrameSizeSamples = 960; + int channels = atoi(argv[1]); + if (channels < 1 || channels > 2) { + std::cout << "Disallowed number of channels " << channels << std::endl; + return -1; + } + + int outputs = atoi(argv[2]); + if (outputs < 1 || outputs > 2) { + std::cout << "Disallowed number of outputs " << outputs << std::endl; + return -1; + } + + const int data_size = channels * kFrameSizeSamples; + webrtc::scoped_ptr in(new int16_t[data_size]); + + std::string input_filename = argv[3]; + std::string output_filename = argv[4]; + + std::cout << "Input file: " << input_filename << std::endl; + std::cout << "Output file: " << output_filename << std::endl; + + FILE* in_file = fopen(input_filename.c_str(), "rb"); + if (!in_file) { + std::cout << "Cannot open input file " << input_filename << std::endl; + return -1; + } + + FILE* out_file = fopen(output_filename.c_str(), "wb"); + if (!out_file) { + std::cout << "Cannot open output file " << output_filename << std::endl; + return -1; + } + + webrtc::AudioClassifier classifier; + int frame_counter = 0; + int music_counter = 0; + while (fread(in.get(), sizeof(*in.get()), + data_size, in_file) == (size_t) data_size) { + bool is_music = classifier.Analysis(in.get(), data_size, channels); + if (!fwrite(&is_music, sizeof(is_music), 1, out_file)) { + 
std::cout << "Error writing." << std::endl; + return -1; + } + if (is_music) { + music_counter++; + } + std::cout << "frame " << frame_counter << " decision " << is_music; + if (outputs == 2) { + float music_prob = classifier.music_probability(); + if (!fwrite(&music_prob, sizeof(music_prob), 1, out_file)) { + std::cout << "Error writing." << std::endl; + return -1; + } + std::cout << " music prob " << music_prob; + } + std::cout << std::endl; + frame_counter++; + } + std::cout << frame_counter << " frames processed." << std::endl; + if (frame_counter > 0) { + float music_percentage = music_counter / static_cast(frame_counter); + std::cout << music_percentage << " percent music." << std::endl; + } + + fclose(in_file); + fclose(out_file); + return 0; +} diff --git a/webrtc/modules/modules.gyp b/webrtc/modules/modules.gyp index 10da58a53..5d0827cf7 100644 --- a/webrtc/modules/modules.gyp +++ b/webrtc/modules/modules.gyp @@ -115,6 +115,7 @@ 'audio_coding/codecs/isac/fix/source/transform_unittest.cc', 'audio_coding/codecs/isac/main/source/isac_unittest.cc', 'audio_coding/codecs/opus/opus_unittest.cc', + 'audio_coding/neteq4/audio_classifier_unittest.cc', 'audio_coding/neteq4/audio_multi_vector_unittest.cc', 'audio_coding/neteq4/audio_vector_unittest.cc', 'audio_coding/neteq4/background_noise_unittest.cc', diff --git a/webrtc/modules/modules_unittests.isolate b/webrtc/modules/modules_unittests.isolate index 06f8a2dd0..e4139ba64 100644 --- a/webrtc/modules/modules_unittests.isolate +++ b/webrtc/modules/modules_unittests.isolate @@ -15,6 +15,12 @@ '../../../data/', '../../../resources/', ], + 'isolate_dependency_tracked': [ + '../../../resources/short_mixed_mono_48.dat', + '../../../resources/short_mixed_mono_48.pcm', + '../../../resources/short_mixed_stereo_48.dat', + '../../../resources/short_mixed_stereo_48.pcm', + ], }, }], ['OS=="linux" or OS=="mac" or OS=="win"', { @@ -72,6 +78,10 @@ 
'../../resources/remote_bitrate_estimator/VideoSendersTest_BweTest_SteadyLoss_0_TOF.bin', '../../resources/remote_bitrate_estimator/VideoSendersTest_BweTest_UnlimitedSpeed_0_AST.bin', '../../resources/remote_bitrate_estimator/VideoSendersTest_BweTest_UnlimitedSpeed_0_TOF.bin', + '../../resources/short_mixed_mono_48.dat', + '../../resources/short_mixed_mono_48.pcm', + '../../resources/short_mixed_stereo_48.dat', + '../../resources/short_mixed_stereo_48.pcm', '../../resources/sprint-downlink.rx', '../../resources/sprint-uplink.rx', '../../resources/synthetic-trace.rx',