Tomographer v2.0
Tomographer C++ Framework Documentation
mhrw.h
Go to the documentation of this file.
1 /* This file is part of the Tomographer project, which is distributed under the
2  * terms of the MIT license.
3  *
4  * The MIT License (MIT)
5  *
6  * Copyright (c) 2015 ETH Zurich, Institute for Theoretical Physics, Philippe Faist
7  *
8  * Permission is hereby granted, free of charge, to any person obtaining a copy
9  * of this software and associated documentation files (the "Software"), to deal
10  * in the Software without restriction, including without limitation the rights
11  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12  * copies of the Software, and to permit persons to whom the Software is
13  * furnished to do so, subject to the following conditions:
14  *
15  * The above copyright notice and this permission notice shall be included in
16  * all copies or substantial portions of the Software.
17  *
18  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24  * SOFTWARE.
25  */
26 
27 #ifndef _TOMOGRAPHER_MHRW_H
28 #define _TOMOGRAPHER_MHRW_H
29 
#include <cmath>
#include <cstddef>
#include <iomanip>
#include <limits>
#include <ostream>
#include <random>
#include <string>
#include <utility>

#include <tomographer2/tools/fmt.h>
40 
41 
42 
49 namespace Tomographer {
50 
51 
52 
53 
54 
55 
56 
57 
// Possible values of an MHWalker's UseFnSyntaxType, used as the second template argument of
// tomo_internal::MHRandomWalk_helper_decide_jump:
//   - MHUseFnValue: the walker provides the MH function value at each point (fnVal()).
//   - MHUseFnLogValue: the walker provides the logarithm of the MH function value (fnLogVal()).
//   - MHUseFnRelativeValue: the walker directly provides the ratio of the function values of two
//     consecutive points (fnRelVal()).
// NOTE(review): the enumerator lines (listing lines 59-71) are missing from this rendered
// listing; restore them from the original mhrw.h rather than guessing their values.
58 enum {
72 };
73 
74 
75 namespace tomo_internal {
76 
84  template<typename MHWalker, int UseFnSyntaxType>
85  struct MHRandomWalk_helper_decide_jump
86  {
90  typedef typename MHWalker::PointType PointType;
94  typedef typename MHWalker::FnValueType FnValueType;
95 
102  static inline FnValueType get_ptval(MHWalker & mhwalker, const PointType & curpt)
103  {
104  (void)mhwalker; (void)curpt;
105  tomographer_assert(0 && "UNKNOWN UseFnSyntaxType: Not implemented");
106  }
120  static inline double get_a_value(MHWalker & /*mhwalker*/, const PointType & /*newpt*/, FnValueType /*newptval*/,
121  const PointType & /*curpt*/, FnValueType /*curptval*/)
122  {
123  tomographer_assert(0 && "UNKNOWN UseFnSyntaxType: Not implemented");
124  return 0;
125  }
126  };
127 
131  template<typename MHWalker>
132  struct MHRandomWalk_helper_decide_jump<MHWalker, MHUseFnValue>
133  {
134  typedef typename MHWalker::PointType PointType;
135  typedef typename MHWalker::FnValueType FnValueType;
136 
137  static inline FnValueType get_ptval(MHWalker & mhwalker, const PointType & curpt)
138  {
139  return mhwalker.fnVal(curpt);
140  }
141  static inline double get_a_value(MHWalker & /*mhwalker*/, const PointType & /*newpt*/, double newptval,
142  const PointType & /*curpt*/, double curptval)
143  {
144  return (newptval / curptval);
145  }
146  };
147 
151  template<typename MHWalker>
152  struct MHRandomWalk_helper_decide_jump<MHWalker, MHUseFnLogValue>
153  {
154  typedef typename MHWalker::PointType PointType;
155  typedef typename MHWalker::FnValueType FnValueType;
156 
157  static inline FnValueType get_ptval(MHWalker & mhwalker, const PointType & curpt)
158  {
159  return mhwalker.fnLogVal(curpt);
160  }
161  static inline double get_a_value(MHWalker & /*mhwalker*/, const PointType & /*newpt*/, FnValueType newptval,
162  const PointType & /*curpt*/, FnValueType curptval)
163  {
164  return (newptval > curptval) ? 1.0 : exp(newptval - curptval);
165  }
166  };
167 
171  template<typename MHWalker>
172  struct MHRandomWalk_helper_decide_jump<MHWalker, MHUseFnRelativeValue>
173  {
174  typedef typename MHWalker::PointType PointType;
175  typedef int FnValueType; // dummy FnValueType
176 
177  static inline FnValueType get_ptval(MHWalker & /*mhwalker*/, const PointType & /*curpt*/)
178  {
179  return 0;
180  }
181  static inline double get_a_value(MHWalker & mhwalker, const PointType & newpt, FnValueType /*newptval*/,
182  const PointType & curpt, FnValueType /*curptval*/)
183  {
184  return mhwalker.fnRelVal(newpt, curpt);
185  }
186  };
187 };
188 
189 
190 
/** \brief Specify the parameters of a Metropolis-Hastings random walk.
 *
 * \tparam CountIntType_ integer type used to count iterations (sweep size and numbers of sweeps)
 * \tparam StepRealType_ type of the step size
 */
template<typename CountIntType_, typename StepRealType_>
struct MHRWParams
{
  //! Integer type used to count iterations.
  using CountIntType = CountIntType_;
  //! Type of the step size.
  using StepRealType = StepRealType_;

  //! Default constructor: every field is initialized from the literal 0.
  explicit MHRWParams()
    : step_size(0),
      n_sweep(0),
      n_therm(0),
      n_run(0)
  {
  }

  //! Construct the parameters from explicit values for each field.
  MHRWParams(StepRealType step_size_init, CountIntType n_sweep_init,
             CountIntType n_therm_init, CountIntType n_run_init)
    : step_size(step_size_init),
      n_sweep(n_sweep_init),
      n_therm(n_therm_init),
      n_run(n_run_init)
  {
  }

  //! The step size of the random walk.
  StepRealType step_size;

  //! The number of individual updates to collect together in a "sweep".
  CountIntType n_sweep;

  //! Number of thermalization sweeps.
  CountIntType n_therm;

  //! Number of live sweeps.
  CountIntType n_run;
};
227 
228 
229 
230 template<typename CountIntType, typename StepRealType>
231 std::ostream & operator<<(std::ostream & str, const MHRWParams<CountIntType,StepRealType> & p)
232 {
233  str << "MHRWParams(step_size=" << p.step_size << ",n_sweep=" << p.n_sweep
234  << ",n_therm=" << p.n_therm << ",n_run=" << p.n_run << ")";
235  return str;
236 }
237 
238 
239 
279 template<typename Rng_, typename MHWalker_, typename MHRWStatsCollector_, typename LoggerType_,
280  typename CountIntType_ = int>
282  : public virtual Tools::NeedOwnOperatorNew<typename MHWalker_::PointType>::ProviderType
283 {
284 public:
286  typedef Rng_ Rng;
288  typedef MHWalker_ MHWalker;
290  typedef MHRWStatsCollector_ MHRWStatsCollector;
292  typedef LoggerType_ LoggerType;
295  typedef CountIntType_ CountIntType;
296 
298  typedef typename MHWalker::PointType PointType;
300  typedef typename MHWalker::StepRealType StepRealType;
301 
304 
306 #ifndef TOMOGRAPHER_PARSED_BY_DOXYGEN
307  typedef typename tomo_internal::MHRandomWalk_helper_decide_jump<MHWalker,MHWalker::UseFnSyntaxType>::FnValueType
308  FnValueType;
309 #else
310  typedef _FnValueType FnValueType;
311 #endif
312 
313  enum {
315  UseFnSyntaxType = MHWalker::UseFnSyntaxType
316  };
317 
318 private:
319  const MHRWParamsType _n;
320 
321  Rng & _rng;
322  MHWalker & _mhwalker;
323  MHRWStatsCollector & _stats;
325 
327  PointType curpt;
331  FnValueType curptval;
332 
337  CountIntType num_accepted;
341  CountIntType num_live_points;
342 
343 public:
344 
346  MHRandomWalk(StepRealType step_size, CountIntType n_sweep, CountIntType n_therm, CountIntType n_run,
347  MHWalker & mhwalker, MHRWStatsCollector & stats,
348  Rng & rng, LoggerType & logger_)
349  : _n(step_size, n_sweep, n_therm, n_run),
350  _rng(rng),
351  _mhwalker(mhwalker),
352  _stats(stats),
353  _logger(TOMO_ORIGIN, logger_),
354  curpt(),
355  curptval(),
356  num_accepted(0),
357  num_live_points(0)
358  {
359  _logger.debug([&](std::ostream & stream) {
360  stream << "constructor(). n_sweep=" << n_sweep << ", step_size=" << step_size
361  << "n_therm=" << n_therm << ", n_run=" << n_run;
362  });
363  }
365  template<typename MHRWParamsType>
366  MHRandomWalk(MHRWParamsType&& n_rw,
367  MHWalker & mhwalker, MHRWStatsCollector & stats,
368  Rng & rng, LoggerType & logger_)
369  : _n(std::forward<MHRWParamsType>(n_rw)),
370  _rng(rng),
371  _mhwalker(mhwalker),
372  _stats(stats),
373  _logger(TOMO_ORIGIN, logger_),
374  curpt(),
375  curptval(),
376  num_accepted(0),
377  num_live_points(0)
378  {
379  _logger.debug([&](std::ostream & s) { s << "constructor(). mhrw parameters = " << _n; });
380  }
381 
382  MHRandomWalk(const MHRandomWalk & other) = delete;
383 
384 
386  inline MHRWParamsType mhrwParams() const { return _n; }
387 
389  inline StepRealType stepSize() const { return _n.step_size; }
390 
392  inline CountIntType nSweep() const { return _n.n_sweep; }
394  inline CountIntType nTherm() const { return _n.n_therm; }
396  inline CountIntType nRun() const { return _n.n_run; }
397 
398 
399 
403  inline bool hasAcceptanceRatio() const
404  {
405  return (num_live_points > 0);
406  }
409  template<typename RatioType = double>
410  inline RatioType acceptanceRatio() const
411  {
412  return RatioType(num_accepted) / RatioType(num_live_points);
413  }
414 
415 
420  inline const PointType & getCurrentPoint() const
421  {
422  return curpt;
423  }
424 
432  inline const FnValueType & getCurrentPointValue() const
433  {
434  return curptval;
435  }
436 
442  inline void setCurrentPoint(const PointType& pt)
443  {
444  curpt = pt;
445  curptval = tomo_internal::MHRandomWalk_helper_decide_jump<MHWalker,UseFnSyntaxType>::get_ptval(_mhwalker, curpt);
446  _logger.longdebug([&](std::ostream & s) {
447  s << "setCurrentPoint: set internal state. Value = " << curptval << "; Point =\n" << pt << "\n";
448  });
449  }
450 
451 
452 private:
453 
456  inline void _init()
457  {
458  num_accepted = 0;
459  num_live_points = 0;
460 
461  // starting point
462  curpt = _mhwalker.startPoint();
463  curptval = tomo_internal::MHRandomWalk_helper_decide_jump<MHWalker,UseFnSyntaxType>::get_ptval(_mhwalker, curpt);
464 
465  _mhwalker.init();
466  _stats.init();
467  _logger.longdebug("_init() done.");
468  }
471  inline void _thermalizing_done()
472  {
473  _mhwalker.thermalizingDone();
474  _stats.thermalizingDone();
475  _logger.longdebug("_thermalizing_done() done.");
476  }
479  inline void _done()
480  {
481  _mhwalker.done();
482  _stats.done();
483  _logger.longdebug("_done() done.");
484  }
485 
493  inline void _move(CountIntType k, bool is_thermalizing, bool is_live_iter)
494  {
495  _logger.longdebug("_move()");
496  // The reason `step_size` is passed to jump_fn instead of leaving jump_fn itself
497  // handle the step size, is that we might in the future want to dynamically adapt the
498  // step size according to the acceptance ratio. That would have to be done in this
499  // class.
500  PointType newpt = _mhwalker.jumpFn(curpt, _n.step_size);
501 
502  FnValueType newptval;
503 
504  newptval = tomo_internal::MHRandomWalk_helper_decide_jump<MHWalker,UseFnSyntaxType>::get_ptval(_mhwalker, newpt);
505 
506  double a = tomo_internal::MHRandomWalk_helper_decide_jump<MHWalker,UseFnSyntaxType>::get_a_value(
507  _mhwalker, newpt, newptval, curpt, curptval
508  );
509 
510  // accept move?
511  bool accept = true;
512  if (a < 1.0) {
513  accept = bool( _rng()-_rng.min() <= a*(_rng.max()-_rng.min()) );
514  }
515 
516  // track acceptance ratio, except if we are thermalizing
517  if (!is_thermalizing) {
518  num_accepted += accept ? 1 : 0;
519  ++num_live_points;
520  }
521 
522  _stats.rawMove(k, is_thermalizing, is_live_iter, accept, a, newpt, newptval, curpt, curptval, *this);
523 
524  _logger.longdebug([&](std::ostream & stream) {
525  stream << (is_thermalizing?"T":"#") << std::setw(3) << k << ": " << (accept?"AC":"RJ") << " "
526  << std::setprecision(4)
527  << "a=" << std::setw(5) << a << ", newptval=" << std::setw(5) << newptval
528  << ", curptval=" << std::setw(5) << curptval << ", accept_ratio="
529  << (!is_thermalizing ? Tools::fmts("%.2g", this->acceptanceRatio()) : std::string("N/A"))
530  << Tools::streamIfPossible(curpt, "\ncurpt = ", "", "");
531  });
532 
533  if (accept) {
534  // update the internal state of the random walk
535  curpt = newpt;
536  curptval = newptval;
537  }
538  _logger.longdebug("_move() done.");
539  }
540 
545  inline void _process_sample(CountIntType k, CountIntType n)
546  {
547  _stats.processSample(k, n, curpt, curptval, *this);
548  _logger.longdebug("_process_sample() done.");
549  }
550 
551 
552 public:
553 
560  void run()
561  {
562  _init();
563 
564  CountIntType k;
565 
566  _logger.longdebug([&](std::ostream & s) {
567  s << "Starting random walk, sweep size = " << _n.n_sweep << ", step size = " << _n.step_size
568  << ", # therm sweeps = " << _n.n_therm << ", # live sweeps = " << _n.n_run;
569  });
570 
571  const CountIntType num_thermalize = _n.n_sweep * _n.n_therm;
572 
573  for (k = 0; k < num_thermalize; ++k) {
574  // calculate a candidate jump point and see if we accept the move
575  _move(k, true, false);
576  }
577 
578  _thermalizing_done();
579 
580  _logger.longdebug("Thermalizing done, starting live runs.");
581 
582  const CountIntType num_run = _n.n_sweep * _n.n_run;
583 
584  CountIntType n = 0; // number of live samples
585 
586  for (k = 0; k < num_run; ++k) {
587 
588  bool is_live_iter = ((k+1) % _n.n_sweep == 0);
589 
590  // calculate a candidate jump point and see if we accept the move
591  _move(k, false, is_live_iter);
592 
593  if (is_live_iter) {
594  _process_sample(k, n);
595  ++n;
596  }
597 
598  }
599 
600  _done();
601 
602  _logger.longdebug("Random walk completed.");
603 
604  return;
605  }
606 };
607 
608 
609 
610 } // namespace Tomographer
611 
612 
613 
614 
615 #endif
A Metropolis-Hastings Random Walk.
Definition: mhrw.h:281
Utilities for formatting strings.
MHRWParamsType mhrwParams() const
The parameters of the random walk.
Definition: mhrw.h:386
MHWalker::StepRealType StepRealType
The type of a step size of the random walk.
Definition: mhrw.h:300
MHWalker::PointType PointType
The type of a point in the random walk.
Definition: mhrw.h:298
Base namespace for the Tomographer project.
Definition: densellh.h:44
void debug(const char *fmt,...)
Generate a log message with level Logger::DEBUG (printf-like syntax)
Definition: loggers.h:1916
CountIntType n_therm
Number of thermalization sweeps.
Definition: mhrw.h:221
CountIntType n_run
Number of live sweeps.
Definition: mhrw.h:225
StepRealType stepSize() const
Get the step size of the random walk.
Definition: mhrw.h:389
MHRWStatsCollector_ MHRWStatsCollector
The stats collector type (see MHRWStatsCollector Interface)
Definition: mhrw.h:290
Provide appropriate operator new() definitions for a structure which has a member of the given stored...
void setCurrentPoint(const PointType &pt)
Force manual state of random walk.
Definition: mhrw.h:442
STL namespace.
bool hasAcceptanceRatio() const
Query whether we have any statistics about the acceptance ratio. This is false, for example, during the thermalizing runs.
Definition: mhrw.h:403
Provides the MH function value at each point (see Role of UseFnSyntaxType)
Definition: mhrw.h:62
RatioType acceptanceRatio() const
Return the acceptance ratio so far.
Definition: mhrw.h:410
MHRWParams< CountIntType, StepRealType > MHRWParamsType
The struct which can hold the parameters of this random walk.
Definition: mhrw.h:303
const PointType & getCurrentPoint() const
Access the current state of the random walk.
Definition: mhrw.h:420
MHRandomWalk(MHRWParamsType &&n_rw, MHWalker &mhwalker, MHRWStatsCollector &stats, Rng &rng, LoggerType &logger_)
Simple constructor, initializes the given fields.
Definition: mhrw.h:366
MHWalker_ MHWalker
The random walker type which knows about the state space and jump function.
Definition: mhrw.h:288
StepRealType step_size
The step size of the random walk.
Definition: mhrw.h:213
#define TOMO_ORIGIN
Use this as the argument for a Tomographer::Logger::LocalLogger constructor.
Definition: loggers.h:1658
_FnValueType FnValueType
The type of the Metropolis-Hastings function value. (See class documentation)
Definition: mhrw.h:310
T setw(T...args)
CountIntType nRun() const
Number of live run sweeps.
Definition: mhrw.h:396
STL class.
const FnValueType & getCurrentPointValue() const
Access the current function value of the random walk.
Definition: mhrw.h:432
MHRandomWalk(StepRealType step_size, CountIntType n_sweep, CountIntType n_therm, CountIntType n_run, MHWalker &mhwalker, MHRWStatsCollector &stats, Rng &rng, LoggerType &logger_)
Simple constructor, initializes the given fields.
Definition: mhrw.h:346
LoggerType_ LoggerType
The logger type which will be provided by the user to the constructor (see Logging and Loggers) ...
Definition: mhrw.h:292
void longdebug(const char *fmt,...)
Generate a log message with level Logger::LONGDEBUG (printf-like syntax)
Definition: loggers.h:1910
_Unspecified streamIfPossible(const T &obj)
Utility to stream an object, but only if an operator<< overload exists for it.
Definition: fmt.h:307
Provides the logarithm MH function value at each point (see Role of UseFnSyntaxType) ...
Definition: mhrw.h:66
Rng_ Rng
Random number generator type (see C++ std::random)
Definition: mhrw.h:286
Some C++ utilities, with a tad of C++11 tricks.
void run()
Run the random walk. (pun intended)
Definition: mhrw.h:560
Provides directly the ratio of the function values for two consecutive points of the MH random walk (...
Definition: mhrw.h:71
Managing the need for specific overrides to operator new() for some types (especially Eigen types) ...
CountIntType n_sweep
The number of individual updates to collect together in a "sweep".
Definition: mhrw.h:217
CountIntType nTherm() const
Number of thermalizing sweeps.
Definition: mhrw.h:394
std::string fmts(const char *fmt,...)
printf-style format to a std::string
Definition: fmt.h:124
Specify the parameters of a Metropolis-Hastings random walk.
Definition: mhrw.h:197
CountIntType nSweep() const
Number of iterations in a sweep.
Definition: mhrw.h:392
CountIntType_ CountIntType
The type used for counting numbers of iterations (see, e.g. nSweep() or MHRWParams) ...
Definition: mhrw.h:295
T setprecision(T...args)
Binning Analysis in a Metropolis-Hastings random walk.
STL class.
Utilities for logging messages.