// File: crn_tree_clusterizer.h // See Copyright Notice and license at the end of inc/crnlib.h #pragma once #include "crn_matrix.h" namespace crnlib { template class tree_clusterizer { public: tree_clusterizer() : m_overall_variance(0.0f) { } void clear() { m_hist.clear(); m_codebook.clear(); m_nodes.clear(); m_overall_variance = 0.0f; } void add_training_vec(const VectorType& v, uint weight) { const std::pair insert_result( m_hist.insert( std::make_pair(v, 0U) ) ); typename vector_map_type::iterator it(insert_result.first); uint max_weight = UINT_MAX - weight; if (weight > max_weight) it->second = UINT_MAX; else it->second = it->second + weight; } bool generate_codebook(uint max_size) { if (m_hist.empty()) return false; double ttsum = 0.0f; vq_node root; root.m_vectors.reserve(static_cast(m_hist.size())); for (typename vector_map_type::const_iterator it = m_hist.begin(); it != m_hist.end(); ++it) { const VectorType& v = it->first; const uint weight = it->second; root.m_centroid += (v * (float)weight); root.m_total_weight += weight; root.m_vectors.push_back( std::make_pair(v, weight) ); ttsum += v.dot(v) * weight; } root.m_variance = (float)(ttsum - (root.m_centroid.dot(root.m_centroid) / root.m_total_weight)); root.m_centroid *= (1.0f / root.m_total_weight); m_nodes.clear(); m_nodes.reserve(max_size * 2 + 1); m_nodes.push_back(root); // Warning: if this code is NOT compiled with -fno-strict-aliasing, m_nodes.get_ptr() can be NULL here. (Argh!) uint total_leaves = 1; while (total_leaves < max_size) { int worst_node_index = -1; float worst_variance = -1.0f; for (uint i = 0; i < m_nodes.size(); i++) { vq_node& node = m_nodes[i]; // Skip internal and unsplittable nodes. if ((node.m_left != -1) || (node.m_unsplittable)) continue; if (node.m_variance > worst_variance) { worst_variance = node.m_variance; worst_node_index = i; } } if (worst_variance <= 0.0f) break; split_node(worst_node_index); total_leaves++; } m_codebook.clear(); m_overall_variance = 0.0f; for (uint i = 0; i < m_nodes.size(); i++) { vq_node& node = m_nodes[i]; if (node.m_left != -1) { CRNLIB_ASSERT(node.m_right != -1); continue; } CRNLIB_ASSERT((node.m_left == -1) && (node.m_right == -1)); node.m_codebook_index = m_codebook.size(); m_codebook.push_back(node.m_centroid); m_overall_variance += node.m_variance; } return true; } inline float get_overall_variance() const { return m_overall_variance; } inline uint get_codebook_size() const { return m_codebook.size(); } inline const VectorType& get_codebook_entry(uint index) const { return m_codebook[index]; } typedef crnlib::vector vector_vec_type; inline const vector_vec_type& get_codebook() const { return m_codebook; } uint find_best_codebook_entry(const VectorType& v) const { uint cur_node_index = 0; for ( ; ; ) { const vq_node& cur_node = m_nodes[cur_node_index]; if (cur_node.m_left == -1) return cur_node.m_codebook_index; const vq_node& left_node = m_nodes[cur_node.m_left]; const vq_node& right_node = m_nodes[cur_node.m_right]; float left_dist = left_node.m_centroid.squared_distance(v); float right_dist = right_node.m_centroid.squared_distance(v); if (left_dist < right_dist) cur_node_index = cur_node.m_left; else cur_node_index = cur_node.m_right; } } uint find_best_codebook_entry_fs(const VectorType& v) const { float best_dist = math::cNearlyInfinite; uint best_index = 0; for (uint i = 0; i < m_codebook.size(); i++) { float dist = m_codebook[i].squared_distance(v); if (dist < best_dist) { best_dist = dist; best_index = i; if (best_dist == 0.0f) break; } } return best_index; } private: typedef std::map vector_map_type; vector_map_type m_hist; struct vq_node { vq_node() : m_centroid(cClear), m_total_weight(0), m_left(-1), m_right(-1), m_codebook_index(-1), m_unsplittable(false) { } VectorType m_centroid; uint64 m_total_weight; float m_variance; crnlib::vector< std::pair > m_vectors; int m_left; int m_right; int m_codebook_index; bool m_unsplittable; }; typedef crnlib::vector node_vec_type; node_vec_type m_nodes; vector_vec_type m_codebook; float m_overall_variance; random m_rand; void split_node(uint index) { vq_node& parent_node = m_nodes[index]; if (parent_node.m_vectors.size() == 1) return; VectorType furthest(0); double furthest_dist = -1.0f; for (uint i = 0; i < parent_node.m_vectors.size(); i++) { const VectorType& v = parent_node.m_vectors[i].first; double dist = v.squared_distance(parent_node.m_centroid); if (dist > furthest_dist) { furthest_dist = dist; furthest = v; } } VectorType opposite; double opposite_dist = -1.0f; for (uint i = 0; i < parent_node.m_vectors.size(); i++) { const VectorType& v = parent_node.m_vectors[i].first; double dist = v.squared_distance(furthest); if (dist > opposite_dist) { opposite_dist = dist; opposite = v; } } VectorType left_child((furthest + parent_node.m_centroid) * .5f); VectorType right_child((opposite + parent_node.m_centroid) * .5f); if (parent_node.m_vectors.size() > 2) { const uint N = VectorType::num_elements; matrix covar; covar.clear(); for (uint i = 0; i < parent_node.m_vectors.size(); i++) { const VectorType v(parent_node.m_vectors[i].first - parent_node.m_centroid); const VectorType w(v * (float)parent_node.m_vectors[i].second); for (uint x = 0; x < N; x++) for (uint y = x; y < N; y++) covar[x][y] = covar[x][y] + v[x] * w[y]; } if (N > 1) { //for (uint x = 0; x < (N - 1); x++) for (uint x = 0; x != (N - 1); x++) for (uint y = x + 1; y < N; y++) covar[y][x] = covar[x][y]; } covar /= float(parent_node.m_total_weight); VectorType axis(1.0f); // Starting with an estimate of the principle axis should work better, but doesn't in practice? //left_child - right_child); //axis.normalize(); for (uint iter = 0; iter < 10; iter++) { VectorType x; double max_sum = 0; for (uint i = 0; i < N; i++) { double sum = 0; for (uint j = 0; j < N; j++) sum += axis[j] * covar[i][j]; x[i] = (float)sum; max_sum = i ? math::maximum(max_sum, sum) : sum; } if (max_sum != 0.0f) x *= (float)(1.0f / max_sum); axis = x; } axis.normalize(); VectorType new_left_child(0.0f); VectorType new_right_child(0.0f); double left_weight = 0.0f; double right_weight = 0.0f; for (uint i = 0; i < parent_node.m_vectors.size(); i++) { const float weight = (float)parent_node.m_vectors[i].second; const VectorType& v = parent_node.m_vectors[i].first; double t = (v - parent_node.m_centroid) * axis; if (t < 0.0f) { new_left_child += v * weight; left_weight += weight; } else { new_right_child += v * weight; right_weight += weight; } } if ((left_weight > 0.0f) && (right_weight > 0.0f)) { left_child = new_left_child * (float)(1.0f/left_weight); right_child = new_right_child * (float)(1.0f/right_weight); } } uint64 left_weight = 0; uint64 right_weight = 0; crnlib::vector< std::pair > left_children; crnlib::vector< std::pair > right_children; left_children.reserve(parent_node.m_vectors.size() / 2); right_children.reserve(parent_node.m_vectors.size() / 2); float prev_total_variance = 1e+10f; float left_variance = 0.0f; float right_variance = 0.0f; // FIXME: Excessive upper limit const uint cMaxLoops = 1024; for (uint total_loops = 0; total_loops < cMaxLoops; total_loops++) { left_children.resize(0); right_children.resize(0); VectorType new_left_child(cClear); VectorType new_right_child(cClear); double left_ttsum = 0.0f; double right_ttsum = 0.0f; left_weight = 0; right_weight = 0; for (uint i = 0; i < parent_node.m_vectors.size(); i++) { const VectorType& v = parent_node.m_vectors[i].first; const uint weight = parent_node.m_vectors[i].second; double left_dist2 = left_child.squared_distance(v); double right_dist2 = right_child.squared_distance(v); if (left_dist2 < right_dist2) { left_children.push_back(parent_node.m_vectors[i]); new_left_child += (v * (float)weight); left_weight += weight; left_ttsum += v.dot(v) * weight; } else { right_children.push_back(parent_node.m_vectors[i]); new_right_child += (v * (float)weight); right_weight += weight; right_ttsum += v.dot(v) * weight; } } if ((!left_weight) || (!right_weight)) { parent_node.m_unsplittable = true; return; } left_variance = (float)(left_ttsum - (new_left_child.dot(new_left_child) / left_weight)); right_variance = (float)(right_ttsum - (new_right_child.dot(new_right_child) / right_weight)); new_left_child *= (1.0f / left_weight); new_right_child *= (1.0f / right_weight); left_child = new_left_child; left_weight = left_weight; right_child = new_right_child; right_weight = right_weight; float total_variance = left_variance + right_variance; if (total_variance < .00001f) break; if (((prev_total_variance - total_variance) / total_variance) < .00001f) break; prev_total_variance = total_variance; } const uint left_child_index = m_nodes.size(); const uint right_child_index = m_nodes.size() + 1; parent_node.m_left = m_nodes.size(); parent_node.m_right = m_nodes.size() + 1; m_nodes.resize(m_nodes.size() + 2); // parent_node is invalid now, because m_nodes has been changed vq_node& left_child_node = m_nodes[left_child_index]; vq_node& right_child_node = m_nodes[right_child_index]; left_child_node.m_centroid = left_child; left_child_node.m_total_weight = left_weight; left_child_node.m_vectors.swap(left_children); left_child_node.m_variance = left_variance; right_child_node.m_centroid = right_child; right_child_node.m_total_weight = right_weight; right_child_node.m_vectors.swap(right_children); right_child_node.m_variance = right_variance; } }; } // namespace crnlib