random_forest_hdf5_impex.hxx
|
 |
36 #ifndef VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX 37 #define VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX 40 #include "random_forest.hxx" 41 #include "hdf5impex.hxx" 47 static const char *
const rf_hdf5_options =
"_options";
// Name handed to problemspec_export_HDF5() / problemspec_import_HDF5() as the
// location of the serialized ProblemSpec.
static const char *const rf_hdf5_ext_param     = "_ext_param";
// Dataset holding the class labels; stored next to the ProblemSpec fields but
// read and written separately from them.
static const char *const rf_hdf5_labels        = "labels";
// Not referenced in this header; presumably used by the dt_*_HDF5
// implementations -- TODO confirm.
static const char *const rf_hdf5_topology      = "topology";
// Not referenced in this header; presumably used by the dt_*_HDF5
// implementations -- TODO confirm.
static const char *const rf_hdf5_parameters    = "parameters";
// Prefix of each per-tree name; the padded tree index is appended to it.
static const char *const rf_hdf5_tree          = "Tree_";
// Object that carries the format-version attribute.
static const char *const rf_hdf5_version_group = ".";
// Name of the format-version attribute.
static const char *const rf_hdf5_version_tag   = "vigra_random_forest_version";
// Version written on export; import rejects files whose stored version is
// greater than this value.
static const double      rf_hdf5_version       = 0.1;
60 VIGRA_EXPORT
void options_import_HDF5(HDF5File &, RandomForestOptions &,
63 VIGRA_EXPORT
void options_export_HDF5(HDF5File &,
const RandomForestOptions &,
66 VIGRA_EXPORT
void dt_import_HDF5(HDF5File &, detail::DecisionTree &,
69 VIGRA_EXPORT
void dt_export_HDF5(HDF5File &,
const detail::DecisionTree &,
73 void rf_import_HDF5_to_map(HDF5File & h5context, X & param,
74 const char *
const ignored_label = 0)
77 typedef typename X::map_type map_type;
78 typedef std::pair<typename map_type::iterator, bool> inserter_type;
79 typedef typename map_type::value_type value_type;
80 typedef typename map_type::mapped_type mapped_type;
82 map_type serialized_param;
83 bool ignored_seen = ignored_label == 0;
85 std::vector<std::string> names = h5context.ls();
86 std::vector<std::string>::const_iterator j;
87 for (j = names.begin(); j != names.end(); ++j)
89 if (ignored_label && *j == ignored_label)
95 inserter_type new_array
96 = serialized_param.insert(value_type(*j, mapped_type()));
98 h5context.readAndResize(*j, (*(new_array.first)).second);
100 vigra_precondition(ignored_seen,
"rf_import_HDF5_to_map(): " 101 "labels are missing.");
102 param.make_from_map(serialized_param);
106 void problemspec_import_HDF5(HDF5File & h5context, ProblemSpec<T> & param,
107 const std::string & name)
110 rf_import_HDF5_to_map(h5context, param, rf_hdf5_labels);
112 ArrayVector<T> labels;
113 h5context.readAndResize(rf_hdf5_labels, labels);
114 param.classes_(labels.begin(), labels.end());
119 void rf_export_map_to_HDF5(HDF5File & h5context,
const X & param)
121 typedef typename X::map_type map_type;
122 map_type serialized_param;
124 param.make_map(serialized_param);
125 typename map_type::const_iterator j;
126 for (j = serialized_param.begin(); j != serialized_param.end(); ++j)
127 h5context.write(j->first, j->second);
131 void problemspec_export_HDF5(HDF5File & h5context, ProblemSpec<T>
const & param,
132 const std::string & name)
134 h5context.cd_mk(name);
135 rf_export_map_to_HDF5(h5context, param);
136 h5context.write(rf_hdf5_labels, param.classes);
140 struct padded_number_string_data;
141 class VIGRA_EXPORT padded_number_string
144 padded_number_string_data* padded_number;
146 padded_number_string(
const padded_number_string &);
147 void operator=(
const padded_number_string &);
149 padded_number_string(
int n);
150 std::string operator()(
int k)
const;
151 ~padded_number_string();
154 inline std::string get_cwd(HDF5File & h5context)
156 return h5context.get_absolute_path(h5context.pwd());
175 template<
class T,
class Tag>
176 void rf_export_HDF5(
const RandomForest<T, Tag> & rf,
177 HDF5File & h5context,
178 const std::string & pathname =
"")
181 if (pathname.size()) {
182 cwd = detail::get_cwd(h5context);
183 h5context.cd_mk(pathname);
186 h5context.writeAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
189 detail::options_export_HDF5(h5context, rf.options(), rf_hdf5_options);
191 detail::problemspec_export_HDF5(h5context, rf.ext_param(),
194 int tree_count = rf.options_.tree_count_;
195 detail::padded_number_string tree_number(tree_count);
196 for (
int i = 0; i < tree_count; ++i)
197 detail::dt_export_HDF5(h5context, rf.tree(i),
198 rf_hdf5_tree + tree_number(i));
218 template<
class T,
class Tag>
219 void rf_export_HDF5(
const RandomForest<T, Tag> & rf,
220 const std::string & filename,
221 const std::string & pathname =
"")
223 HDF5File h5context(filename , HDF5File::Open);
224 rf_export_HDF5(rf, h5context, pathname);
240 template<
class T,
class Tag>
241 bool rf_import_HDF5(RandomForest<T, Tag> & rf,
242 HDF5File & h5context,
243 const std::string & pathname =
"")
246 if (pathname.size()) {
247 cwd = detail::get_cwd(h5context);
248 h5context.cd(pathname);
251 if (h5context.existsAttribute(rf_hdf5_version_group, rf_hdf5_version_tag))
254 h5context.readAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
256 vigra_precondition(read_version <= rf_hdf5_version,
257 "rf_import_HDF5(): unexpected file format version.");
260 detail::options_import_HDF5(h5context, rf.options_, rf_hdf5_options);
262 detail::problemspec_import_HDF5(h5context, rf.ext_param_,
266 std::vector<std::string> names = h5context.ls();
267 std::vector<std::string>::const_iterator j;
268 for (j = names.begin(); j != names.end(); ++j)
270 if ((*j->rbegin() ==
'/') && (*j->begin() !=
'_'))
272 rf.trees_.push_back(detail::DecisionTree(rf.ext_param_));
273 detail::dt_import_HDF5(h5context, rf.trees_.back(), *j);
294 template<
class T,
class Tag>
295 bool rf_import_HDF5(RandomForest<T, Tag> & rf,
296 const std::string & filename,
297 const std::string & pathname =
"")
299 HDF5File h5context(filename, HDF5File::OpenReadOnly);
300 return rf_import_HDF5(rf, h5context, pathname);
#endif // VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX