464 lines
13 KiB
464 lines
13 KiB
![]() |
// File: crn_tree_clusterizer.h
// See Copyright Notice and license at the end of inc/crnlib.h
#pragma once
#include "crn_matrix.h"
namespace crnlib
// tree_clusterizer: a tree-structured vector quantizer. Weighted training
// vectors are collected with add_training_vec(); generate_codebook() then
// grows a binary tree of clusters, each step selecting the splittable leaf
// with the largest variance, and turns the leaves into codebook entries.
// From the calls made in this file, VectorType needs at least: construction
// from cClear / a scalar, dot(), squared_distance(), arithmetic with scalars
// and vectors, operator[], a static num_elements, and operator< (it is used
// as a std::map key).
// NOTE(review): this extract has no braces or access specifiers, and the
// constructor below is missing its initializer list and body.
template<typename VectorType>
class tree_clusterizer
tree_clusterizer() :
void clear()
m_overall_variance = 0.0f;
void add_training_vec(const VectorType& v, uint weight)
const std::pair<typename vector_map_type::iterator, bool> insert_result( m_hist.insert( std::make_pair(v, 0U) ) );
typename vector_map_type::iterator it(insert_result.first);
uint max_weight = UINT_MAX - weight;
if (weight > max_weight)
it->second = UINT_MAX;
it->second = it->second + weight;
// Builds the codebook from the weighted histogram filled by add_training_vec().
// Steps visible here:
//  1. Build a root node that holds every distinct (vector, weight) pair,
//     together with its weighted centroid, total weight and variance.
//  2. While fewer than max_size leaves exist, pick the splittable leaf with
//     the largest variance (presumably split via split_node(), which is not
//     called anywhere in this extract).
//  3. Give each leaf (m_left == -1) a codebook index and sum the leaf
//     variances into m_overall_variance.
// @param max_size  maximum number of leaves, i.e. codebook entries.
// @return false if the histogram is empty (nothing was added), else true.
// NOTE(review): this extract is missing braces and several statements the
// logic depends on:
//  - nothing shown appends `root` to m_nodes or calls split_node();
//  - nothing increments total_leaves, so the while loop as shown cannot end
//    when max_size > 1;
//  - the `continue` that skips internal/unsplittable nodes is missing;
//  - the exit taken when worst_variance <= 0 is missing;
//  - nothing pushes leaf centroids into m_codebook.
// Restore these from the original crnlib source before relying on this.
bool generate_codebook(uint max_size)
if (m_hist.empty())
return false;
// Running sum of weight * |v|^2 over all samples, used for the root variance.
double ttsum = 0.0f;
vq_node root;
for (typename vector_map_type::const_iterator it = m_hist.begin(); it != m_hist.end(); ++it)
const VectorType& v = it->first;
const uint weight = it->second;
root.m_centroid += (v * (float)weight);
root.m_total_weight += weight;
root.m_vectors.push_back( std::make_pair(v, weight) );
ttsum += v.dot(v) * weight;
// Weighted variance about the centroid: sum(w*|v|^2) - |sum(w*v)|^2 / W.
// It is computed before m_centroid is divided by W, while it still holds sum(w*v).
// NOTE(review): m_total_weight is 0 if every added weight was 0, and then
// both divisions below divide by zero.
root.m_variance = (float)(ttsum - (root.m_centroid.dot(root.m_centroid) / root.m_total_weight));
root.m_centroid *= (1.0f / root.m_total_weight);
// A binary tree with max_size leaves has 2*max_size - 1 nodes; reserve that up front.
m_nodes.reserve(max_size * 2 + 1);
// Warning: if this code is NOT compiled with -fno-strict-aliasing, m_nodes.get_ptr() can be NULL here. (Argh!)
uint total_leaves = 1;
while (total_leaves < max_size)
// Find the leaf with the largest variance as the next split candidate.
int worst_node_index = -1;
float worst_variance = -1.0f;
for (uint i = 0; i < m_nodes.size(); i++)
vq_node& node = m_nodes[i];
// Skip internal and unsplittable nodes.
if ((node.m_left != -1) || (node.m_unsplittable))
if (node.m_variance > worst_variance)
worst_variance = node.m_variance;
worst_node_index = i;
// No candidate leaf was found, or the best one has zero variance: nothing left worth splitting.
if (worst_variance <= 0.0f)
// Assign codebook indices to the leaves and total their variances.
m_overall_variance = 0.0f;
for (uint i = 0; i < m_nodes.size(); i++)
vq_node& node = m_nodes[i];
if (node.m_left != -1)
CRNLIB_ASSERT(node.m_right != -1);
CRNLIB_ASSERT((node.m_left == -1) && (node.m_right == -1));
node.m_codebook_index = m_codebook.size();
m_overall_variance += node.m_variance;
return true;
// Returns m_overall_variance. clear() resets it to 0, and generate_codebook()
// recomputes it as a sum of node variances.
inline float get_overall_variance() const
{
   return m_overall_variance;
}
// Number of entries currently stored in the codebook.
inline uint get_codebook_size() const { return m_codebook.size(); }
// Read-only access to codebook entry `index`.
// This method does no range check; callers should pass index < get_codebook_size().
inline const VectorType& get_codebook_entry(uint index) const
{
   const vector_vec_type& codebook = m_codebook;
   return codebook[index];
}
// Container type used to store the codebook entries.
typedef crnlib::vector<VectorType> vector_vec_type;

// Returns the whole codebook by const reference (no copy is made).
inline const vector_vec_type& get_codebook() const { return m_codebook; }
// Fast, approximate nearest-entry lookup. It walks the cluster tree from the
// root, and at each internal node descends into whichever child's centroid
// is closer to v (ties go right). It returns the codebook index stored on
// the leaf it reaches.
// The work is one step per tree level rather than one per codebook entry.
// Because the descent is greedy, the result is not guaranteed to be the true
// nearest entry; find_best_codebook_entry_fs() does an exhaustive search.
// @param v  query vector
// @return codebook index of the leaf reached
uint find_best_codebook_entry(const VectorType& v) const
{
   // The tree must have been built (generate_codebook) before querying it.
   CRNLIB_ASSERT(m_nodes.size() > 0);

   uint cur_node_index = 0;

   for ( ; ; )
   {
      const vq_node& cur_node = m_nodes[cur_node_index];

      // Leaves (no children) carry the codebook index assigned by generate_codebook().
      if (cur_node.m_left == -1)
         return cur_node.m_codebook_index;

      const vq_node& left_node = m_nodes[cur_node.m_left];
      const vq_node& right_node = m_nodes[cur_node.m_right];

      const float left_dist = left_node.m_centroid.squared_distance(v);
      const float right_dist = right_node.m_centroid.squared_distance(v);

      // Without the else, the left choice would always be overwritten.
      if (left_dist < right_dist)
         cur_node_index = cur_node.m_left;
      else
         cur_node_index = cur_node.m_right;
   }
}
// Exhaustive ("full search") nearest-entry lookup. It scans every codebook
// entry and returns the index with the smallest squared distance to v,
// stopping early on an exact match. Ties keep the lower index.
// @param v  query vector
// @return index of the closest entry, or 0 if the codebook is empty
uint find_best_codebook_entry_fs(const VectorType& v) const
{
   float best_dist = math::cNearlyInfinite;
   uint best_index = 0;

   for (uint i = 0; i < m_codebook.size(); i++)
   {
      const float dist = m_codebook[i].squared_distance(v);
      if (dist < best_dist)
      {
         best_dist = dist;
         best_index = i;

         // An exact match cannot be beaten; stop scanning.
         if (best_dist == 0.0f)
            break;
      }
   }

   // Always return a value: falling off the end of a non-void function is UB.
   return best_index;
}
typedef std::map<VectorType, uint> vector_map_type;
vector_map_type m_hist;
struct vq_node
vq_node() : m_centroid(cClear), m_total_weight(0), m_left(-1), m_right(-1), m_codebook_index(-1), m_unsplittable(false) { }
VectorType m_centroid;
uint64 m_total_weight;
float m_variance;
crnlib::vector< std::pair<VectorType, uint> > m_vectors;
int m_left;
int m_right;
int m_codebook_index;
bool m_unsplittable;
typedef crnlib::vector<vq_node> node_vec_type;
node_vec_type m_nodes;
vector_vec_type m_codebook;
float m_overall_variance;
random m_rand;
// Splits leaf m_nodes[index] into two child nodes appended to m_nodes.
// Visible steps:
//  1. Seed two child centroids: the sample furthest from the parent centroid
//     and the sample furthest from that one, each averaged with the parent centroid.
//  2. With more than two samples, re-seed from the weighted means on either
//     side of the centroid along an estimate of the principal axis of the
//     weighted covariance (10 power-iteration steps).
//  3. Refine the two centroids with Lloyd-style 2-means for up to cMaxLoops
//     passes. If a pass leaves one side empty, the parent is flagged m_unsplittable.
//  4. Record the child indices on the parent and fill in the new nodes.
// NOTE(review): this extract has lost its braces and several statements:
//  - the body of the single-sample check;
//  - the `else` branches of the side/nearest tests;
//  - the exits from the refinement loop;
//  - any code that fills left_children/right_children or copies samples
//    into the children's m_vectors.
// Restore these from the original crnlib source before relying on this.
void split_node(uint index)
vq_node& parent_node = m_nodes[index];
// A node holding a single sample cannot be split.
// NOTE(review): the body of this check is not shown; presumably it marks the node unsplittable and returns.
if (parent_node.m_vectors.size() == 1)
// Seed 1: the sample furthest from the parent centroid.
VectorType furthest(0);
double furthest_dist = -1.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
const VectorType& v = parent_node.m_vectors[i].first;
double dist = v.squared_distance(parent_node.m_centroid);
if (dist > furthest_dist)
furthest_dist = dist;
furthest = v;
// Seed 2: the sample furthest from seed 1.
VectorType opposite;
double opposite_dist = -1.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
const VectorType& v = parent_node.m_vectors[i].first;
double dist = v.squared_distance(furthest);
if (dist > opposite_dist)
opposite_dist = dist;
opposite = v;
// Initial child centroids: each seed pulled halfway toward the parent centroid.
VectorType left_child((furthest + parent_node.m_centroid) * .5f);
VectorType right_child((opposite + parent_node.m_centroid) * .5f);
// With more than two samples, try better seeds from the principal axis.
if (parent_node.m_vectors.size() > 2)
const uint N = VectorType::num_elements;
// NOTE(review): assumes matrix<> default construction zero-fills covar; confirm in crn_matrix.h.
matrix<N, N, float> covar;
// Accumulate the upper triangle of sum(w * d * d^T), where d = v - centroid.
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
const VectorType v(parent_node.m_vectors[i].first - parent_node.m_centroid);
const VectorType w(v * (float)parent_node.m_vectors[i].second);
for (uint x = 0; x < N; x++)
for (uint y = x; y < N; y++)
covar[x][y] = covar[x][y] + v[x] * w[y];
// Mirror the upper triangle into the lower one so covar is symmetric.
if (N > 1)
//for (uint x = 0; x < (N - 1); x++)
for (uint x = 0; x != (N - 1); x++)
for (uint y = x + 1; y < N; y++)
covar[y][x] = covar[x][y];
// Divide by the total weight to get the weighted covariance.
covar /= float(parent_node.m_total_weight);
VectorType axis(1.0f);
// Starting with an estimate of the principal axis should work better, but doesn't in practice?
//left_child - right_child);
// Power iteration: repeatedly multiply axis by covar and rescale so that the
// largest component becomes 1, approximating the dominant eigenvector.
// NOTE(review): rescales by the largest *signed* component rather than the
// largest magnitude; confirm this is intended.
for (uint iter = 0; iter < 10; iter++)
VectorType x;
double max_sum = 0;
for (uint i = 0; i < N; i++)
double sum = 0;
for (uint j = 0; j < N; j++)
sum += axis[j] * covar[i][j];
x[i] = (float)sum;
max_sum = i ? math::maximum(max_sum, sum) : sum;
if (max_sum != 0.0f)
x *= (float)(1.0f / max_sum);
axis = x;
// Split the samples by which side of the centroid they fall on along axis,
// and take each side's weighted mean as a candidate seed.
VectorType new_left_child(0.0f);
VectorType new_right_child(0.0f);
double left_weight = 0.0f;
double right_weight = 0.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
const float weight = (float)parent_node.m_vectors[i].second;
const VectorType& v = parent_node.m_vectors[i].first;
// The product of two vectors yields a scalar here, presumably a dot product; confirm VectorType::operator*.
double t = (v - parent_node.m_centroid) * axis;
if (t < 0.0f)
new_left_child += v * weight;
left_weight += weight;
new_right_child += v * weight;
right_weight += weight;
// Only adopt the axis-based seeds if both sides received weight.
if ((left_weight > 0.0f) && (right_weight > 0.0f))
left_child = new_left_child * (float)(1.0f/left_weight);
right_child = new_right_child * (float)(1.0f/right_weight);
// Lloyd-style 2-means refinement of left_child / right_child.
uint64 left_weight = 0;
uint64 right_weight = 0;
// NOTE(review): left_children/right_children are reserved but never filled or read in this extract.
crnlib::vector< std::pair<VectorType, uint> > left_children;
crnlib::vector< std::pair<VectorType, uint> > right_children;
left_children.reserve(parent_node.m_vectors.size() / 2);
right_children.reserve(parent_node.m_vectors.size() / 2);
// Large sentinel so the first pass's relative-improvement test does not stop the loop.
float prev_total_variance = 1e+10f;
float left_variance = 0.0f;
float right_variance = 0.0f;
// FIXME: Excessive upper limit
const uint cMaxLoops = 1024;
for (uint total_loops = 0; total_loops < cMaxLoops; total_loops++)
VectorType new_left_child(cClear);
VectorType new_right_child(cClear);
double left_ttsum = 0.0f;
double right_ttsum = 0.0f;
left_weight = 0;
right_weight = 0;
// Assign each sample to the nearer child centroid (ties go right) and accumulate weighted sums.
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
const VectorType& v = parent_node.m_vectors[i].first;
const uint weight = parent_node.m_vectors[i].second;
double left_dist2 = left_child.squared_distance(v);
double right_dist2 = right_child.squared_distance(v);
if (left_dist2 < right_dist2)
new_left_child += (v * (float)weight);
left_weight += weight;
left_ttsum += v.dot(v) * weight;
new_right_child += (v * (float)weight);
right_weight += weight;
right_ttsum += v.dot(v) * weight;
// One side is empty: the samples cannot be separated, so flag the parent.
// NOTE(review): presumably followed by a return in the original; not shown.
if ((!left_weight) || (!right_weight))
parent_node.m_unsplittable = true;
// Per-child variance uses the same identity as the root node:
// sum(w*|v|^2) - |sum(w*v)|^2 / W. The sums are then normalized into centroids.
left_variance = (float)(left_ttsum - (new_left_child.dot(new_left_child) / left_weight));
right_variance = (float)(right_ttsum - (new_right_child.dot(new_right_child) / right_weight));
new_left_child *= (1.0f / left_weight);
new_right_child *= (1.0f / right_weight);
left_child = new_left_child;
// NOTE(review): this and `right_weight = right_weight;` below are self-assignments with no effect.
left_weight = left_weight;
right_child = new_right_child;
right_weight = right_weight;
// Convergence tests: the summed child variance is negligible, or it improved
// by less than 0.001% relative to the previous pass.
// NOTE(review): the statements these conditions guard (presumably break) are missing from this extract.
float total_variance = left_variance + right_variance;
if (total_variance < .00001f)
if (((prev_total_variance - total_variance) / total_variance) < .00001f)
prev_total_variance = total_variance;
// Capture the child indices before resize(): growing m_nodes may reallocate
// its storage and invalidate the parent_node reference.
const uint left_child_index = m_nodes.size();
const uint right_child_index = m_nodes.size() + 1;
parent_node.m_left = m_nodes.size();
parent_node.m_right = m_nodes.size() + 1;
m_nodes.resize(m_nodes.size() + 2);
// parent_node is invalid now, because m_nodes has been changed
vq_node& left_child_node = m_nodes[left_child_index];
vq_node& right_child_node = m_nodes[right_child_index];
// NOTE(review): the children's m_vectors are not assigned anywhere in this
// extract; confirm against the original that samples are handed down.
left_child_node.m_centroid = left_child;
left_child_node.m_total_weight = left_weight;
left_child_node.m_variance = left_variance;
right_child_node.m_centroid = right_child;
right_child_node.m_total_weight = right_weight;
right_child_node.m_variance = right_variance;
} // namespace crnlib