PageRenderTime 43ms CodeModel.GetById 16ms RepoModel.GetById 0ms app.codeStats 0ms

/netlab3.3/demprior.m

http://pmtksupport.googlecode.com/
MATLAB | 407 lines | 298 code | 67 blank | 42 comment | 7 complexity | 0fc5ec86098b64e6a968c5bfb83a3d9b MD5 | raw file
Possible License(s): BSD-2-Clause, GPL-2.0, BSD-3-Clause, GPL-3.0
  function demprior(action)
  %DEMPRIOR Demonstrate sampling from a multi-parameter Gaussian prior.
  %
  % Description
  % This function plots the functions represented by a multi-layer
  % perceptron network when the weights are set to values drawn from a
  % Gaussian prior distribution. The parameters AW1, AB1, AW2 and AB2
  % control the inverse variances of the first-layer weights, the hidden
  % unit biases, the second-layer weights and the output unit biases
  % respectively. Their values can be adjusted on a logarithmic scale
  % using the sliders, or by typing values into the text boxes and
  % pressing the return key.
  %
  % DEMPRIOR(ACTION) dispatches on the string ACTION, which is how the
  % GUI callbacks re-enter this function:
  %   'initialize'  build the demo window and draw a first set of samples
  %                 (the default when ACTION is omitted)
  %   'update'      a slider moved: refresh the text boxes and re-sample
  %   'newval'      a text box was edited: refresh the sliders and re-sample
  %   'replot'      draw a fresh set of samples with the current settings
  %   'help'        open the help window
  % Any other ACTION raises an error.
  %
  % See also
  % MLP
  %
  % Copyright (c) Ian T Nabney (1996-2001)

      if nargin < 1
          action = 'initialize';
      end

      if strcmp(action, 'initialize')
          createDemoWindow;
      elseif strcmp(action, 'update')
          sliderMoved;
      elseif strcmp(action, 'newval')
          textEdited;
      elseif strcmp(action, 'replot')
          drawSamples;
      elseif strcmp(action, 'help')
          showHelp;
      else
          % Previously an unrecognised action was silently ignored.
          error('Unknown action string passed to DEMPRIOR.');
      end
  end

  function createDemoWindow
  % Build the main demo figure, store its handles and draw the first samples.
  %
  % The handle list saved in the figure's UserData has the fixed layout
  %   [fig, 4 sliders (aw1 ab1 aw2 ab2), 4 edit boxes (same order), axes]
  % which the other actions index by position.

      initialAlpha = [0.01 0.1 1.0 1.0];          % aw1, ab1, aw2, ab2
      paramNames = {'aw1', 'ab1', 'aw2', 'ab2'};
      rowY = [0.08 0.30 0.52 0.74];               % bottom edge of each parameter frame
      minval = -5;                                % slider range is log10(alpha)
      maxval = 5;

      fig = figure( ...
          'Name', 'Sampling from a Gaussian prior', ...
          'Position', [50 50 480 380], ...
          'NumberTitle', 'off', ...
          'Color', [0.8 0.8 0.8], ...
          'Visible', 'on');

      % Title bar: frame first, then the text drawn on top of it.
      uicontrol(fig, ...
          'Style', 'frame', ...
          'Units', 'normalized', ...
          'HorizontalAlignment', 'center', ...
          'Position', [0.5 0.82 0.45 0.1], ...
          'BackgroundColor', [0.60 0.60 0.60]);
      uicontrol(fig, ...
          'Style', 'text', ...
          'Units', 'normalized', ...
          'BackgroundColor', [0.6 0.6 0.6], ...
          'Position', [0.54 0.85 0.40 0.05], ...
          'HorizontalAlignment', 'left', ...
          'String', 'Sampling from a Gaussian prior');

      % Controls are created kind by kind (all frames, then all labels, ...)
      % so the frames are created before the controls that sit inside them.
      for k = 1:4
          uicontrol(fig, ...
              'Style', 'frame', ...
              'Units', 'normalized', ...
              'BackgroundColor', [0.6 0.6 0.6], ...
              'Position', [0.05 rowY(k) 0.35 0.18]);
      end

      for k = 1:4
          uicontrol(fig, ...
              'Style', 'text', ...
              'Units', 'normalized', ...
              'HorizontalAlignment', 'left', ...
              'BackgroundColor', [0.6 0.6 0.6], ...
              'Position', [0.07 rowY(k)+0.09 0.06 0.07], ...
              'String', paramNames{k});
      end

      % Sliders hold log10 of the hyperparameter. Min/Max are given before
      % Value so the initial value is never outside the slider range.
      sliders = [];
      for k = 1:4
          h = uicontrol(fig, ...
              'Style', 'slider', ...
              'Units', 'normalized', ...
              'Min', minval, 'Max', maxval, ...
              'Value', log10(initialAlpha(k)), ...
              'BackgroundColor', [0.8 0.8 0.8], ...
              'Position', [0.07 rowY(k)+0.02 0.31 0.05], ...
              'Callback', 'demprior update');
          sliders = [sliders h]; %#ok<AGROW> only four elements
      end

      % The graph box
      haxes = axes('Parent', fig, ...
          'Units', 'normalized', ...
          'Position', [0.5 0.28 0.45 0.45], ...
          'Visible', 'on');

      % Text display of hyper-parameter values
      edits = [];
      for k = 1:4
          h = uicontrol(fig, ...
              'Style', 'edit', ...
              'Units', 'normalized', ...
              'Position', [0.15 rowY(k)+0.09 0.23 0.07], ...
              'String', formatAlpha(initialAlpha(k)), ...
              'Callback', 'demprior newval');
          edits = [edits h]; %#ok<AGROW> only four elements
      end

      % The SAMPLE button
      uicontrol(fig, ...
          'Style', 'push', ...
          'Units', 'normalized', ...
          'BackgroundColor', [0.6 0.6 0.6], ...
          'Position', [0.5 0.08 0.13 0.1], ...
          'String', 'Sample', ...
          'Callback', 'demprior replot');

      % The CLOSE button
      uicontrol(fig, ...
          'Style', 'push', ...
          'Units', 'normalized', ...
          'BackgroundColor', [0.6 0.6 0.6], ...
          'Position', [0.82 0.08 0.13 0.1], ...
          'String', 'Close', ...
          'Callback', 'close(gcf)');

      % The HELP button
      uicontrol(fig, ...
          'Style', 'push', ...
          'Units', 'normalized', ...
          'BackgroundColor', [0.6 0.6 0.6], ...
          'Position', [0.66 0.08 0.13 0.1], ...
          'String', 'Help', ...
          'Callback', 'demprior help');

      % Save handles to objects
      hndlList = [fig sliders edits haxes];
      set(fig, 'UserData', hndlList);

      drawSamples;
  end

  function sliderMoved
  % 'update' action: copy each slider's value into its text box, then re-sample.

      hndlList = get(gcf, 'UserData');
      for k = 1:4
          sliderH = hndlList(1 + k);
          editH = hndlList(5 + k);
          set(editH, 'String', formatAlpha(10^get(sliderH, 'Value')));
      end
      drawSamples;
  end

  function textEdited
  % 'newval' action: copy each text box's value into its slider, then re-sample.
  %
  % A text box is accepted only if it parses to a single finite positive
  % number. The value is clamped to the slider's range; if it cannot be used
  % at all, the slider keeps its current value. Whenever the typed text does
  % not match what the slider ends up holding, the text box is rewritten so
  % the display and the slider agree.

      hndlList = get(gcf, 'UserData');
      for k = 1:4
          sliderH = hndlList(1 + k);
          editH = hndlList(5 + k);

          typed = sscanf(get(editH, 'String'), '%f');
          usable = numel(typed) == 1 && isfinite(typed) && typed > 0;

          if usable
              requested = log10(typed);
              logAlpha = min(max(requested, get(sliderH, 'Min')), ...
                  get(sliderH, 'Max'));
              rewriteText = (logAlpha ~= requested);
          else
              logAlpha = get(sliderH, 'Value');
              rewriteText = true;
          end

          set(sliderH, 'Value', logAlpha);
          if rewriteText
              set(editH, 'String', formatAlpha(10^logAlpha));
          end
      end
      drawSamples;
  end

  function drawSamples
  % 'replot' action: re-sample from the prior and plot the network functions.
  %
  % Reads the four hyperparameters from the sliders of the current figure
  % (stored as log10 values), clears the demo axes and plots NSAMPLE network
  % outputs over the input range [-1, 1].

      oldFigNumber = watchon;

      hndlList = get(gcf, 'UserData');
      aw1 = 10^get(hndlList(2), 'Value');
      ab1 = 10^get(hndlList(3), 'Value');
      aw2 = 10^get(hndlList(4), 'Value');
      ab2 = 10^get(hndlList(5), 'Value');
      haxes = hndlList(10);

      axes(haxes);
      cla
      set(gca, ...
          'Box', 'on', ...
          'Color', [0 0 0], ...
          'XColor', [0 0 0], ...
          'YColor', [0 0 0], ...
          'FontSize', 14);
      axis([-1 1 -10 10]);
      set(gca, 'DefaultLineLineWidth', 2);

      nhidden = 12;
      prior = mlpprior(1, nhidden, 1, aw1, ab1, aw2, ab2);
      xvals = -1:0.005:1;
      nsample = 10;    % Number of samples from prior.

      hold on
      % Each column is one line: the horizontal and vertical zero lines.
      plot([-1 0; 1 0], [0 -10; 0 10], 'b--');

      net = mlp(1, nhidden, 1, 'linear', prior);
      for i = 1:nsample
          % Re-initialise the weights using the prior, then plot the output.
          net = mlpinit(net, prior);
          yvals = mlpfwd(net, xvals');
          plot(xvals', yvals, 'y');
      end

      watchoff(oldFigNumber);
  end

  function showHelp
  % 'help' action: open a separate window containing the explanatory text.

      oldFigNumber = watchon;

      helpfig = figure('Position', [100 100 480 400], ...
          'Name', 'Help', ...
          'NumberTitle', 'off', ...
          'Color', [0.8 0.8 0.8], ...
          'Visible', 'on');

      % The HELP TITLE BAR frame
      uicontrol(helpfig, ...
          'Style', 'frame', ...
          'Units', 'normalized', ...
          'HorizontalAlignment', 'center', ...
          'Position', [0.05 0.82 0.9 0.1], ...
          'BackgroundColor', [0.60 0.60 0.60]);

      % The HELP TITLE BAR text
      uicontrol(helpfig, ...
          'Style', 'text', ...
          'Units', 'normalized', ...
          'BackgroundColor', [0.6 0.6 0.6], ...
          'Position', [0.26 0.85 0.6 0.05], ...
          'HorizontalAlignment', 'left', ...
          'String', 'Help: Sampling from a Gaussian Prior');

      % strcat is kept (rather than [ ... ]) because it drops the trailing
      % blank of each piece, which is why the pieces start with a space.
      helpstr1 = strcat( ...
          'This demonstration shows the effects of sampling from a Gaussian', ...
          ' prior over weights for a two-layer feed-forward network. The', ...
          ' parameters aw1, ab1, aw2 and ab2 control the inverse variances of', ...
          ' the first-layer weights, the hidden unit biases, the second-layer', ...
          ' weights and the output unit biases respectively. Their values can', ...
          ' be adjusted on a logarithmic scale using the sliders, or by', ...
          ' typing values into the text boxes and pressing the return key.', ...
          ' After setting these values, press the ''Sample'' button to see a', ...
          ' new sample from the prior. ');
      helpstr2 = strcat( ...
          'Observe how aw1 controls the horizontal length-scale of the', ...
          ' variation in the functions, ab1 controls the input range over', ...
          ' which such variations occur, aw2 sets the vertical scale of the', ...
          ' output and ab2 sets the vertical off-set of the output. The', ...
          ' network has 12 hidden units. ');
      hstr = {helpstr1, '', helpstr2};

      % The HELP text: a multi-line edit box (Max - Min > 1).
      helpui = uicontrol(helpfig, ...
          'Style', 'edit', ...
          'Units', 'normalized', ...
          'ForegroundColor', [0 0 0], ...
          'HorizontalAlignment', 'left', ...
          'BackgroundColor', [1 1 1], ...
          'Min', 0, ...
          'Max', 2, ...
          'Position', [0.05 0.2 0.9 0.8]);

      % Wrap to 70 columns and resize the box to the height textwrap suggests.
      [hstrw, newpos] = textwrap(helpui, hstr, 70);
      set(helpui, 'String', hstrw, 'Position', [0.05, 0.2, 0.9, newpos(4)]);

      % The CLOSE button
      uicontrol(helpfig, ...
          'Style', 'push', ...
          'Units', 'normalized', ...
          'BackgroundColor', [0.6 0.6 0.6], ...
          'Position', [0.4 0.05 0.2 0.1], ...
          'String', 'Close', ...
          'Callback', 'close(gcf)');

      watchoff(oldFigNumber);
  end

  function str = formatAlpha(alpha)
  % Format a hyperparameter value for display in a text box.

      str = sprintf('%8f', alpha);
  end