PageRenderTime 23ms CodeModel.GetById 28ms RepoModel.GetById 0ms app.codeStats 0ms

/notebooks/tabular_comparison.ipynb

https://gitlab.com/rldotai/td-variance
Jupyter | 439 lines | 439 code | 0 blank | 0 comment | 0 complexity | 08e445cd0c775db7e39caca3b28abccf MD5 | raw file
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {
  7. "collapsed": true
  8. },
  9. "outputs": [],
  10. "source": [
  11. "import numpy as np\n",
  12. "from numpy.linalg import pinv\n",
  13. "\n",
  14. "import pandas as pd\n",
  15. "\n",
  16. "import networkx as nx\n",
  17. "import pydot\n",
  18. "from IPython.display import Image, display\n",
  19. "\n",
  20. "import matplotlib.pyplot as plt\n",
  21. "%matplotlib inline\n",
  22. "\n",
  23. "np.set_printoptions(precision=4, suppress=True)  # numpy: 4 decimals, no scientific notation\n",
  24. "pd.set_option('display.precision', 4)  # full key: bare 'precision' is ambiguous in newer pandas\n",
  25. "pd.set_option('display.float_format', lambda x: '%.4f' % x)"
  26. ]
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": 2,
  31. "metadata": {
  32. "collapsed": true
  33. },
  34. "outputs": [],
  35. "source": [
  36. "%load_ext autoreload\n",
  37. "%autoreload 2"
  38. ]
  39. },
  40. {
  41. "cell_type": "code",
  42. "execution_count": 7,
  43. "metadata": {
  44. "collapsed": false
  45. },
  46. "outputs": [],
  47. "source": [
  48. "from algorithms import TDVarTraces\n",
  49. "from features import Int2Unary\n",
  50. "from simulation import trajectory_gen, compute_return"
  51. ]
  52. },
  53. {
  54. "cell_type": "code",
  55. "execution_count": 9,
  56. "metadata": {
  57. "collapsed": false
  58. },
  59. "outputs": [
  60. {
  61. "data": {
  62. "text/html": [
  63. "<div>\n",
  64. "<table border=\"1\" class=\"dataframe\">\n",
  65. " <thead>\n",
  66. " <tr style=\"text-align: right;\">\n",
  67. " <th></th>\n",
  68. " <th>true_value</th>\n",
  69. " <th>exp_value</th>\n",
  70. " <th>true_var</th>\n",
  71. " <th>exp_var</th>\n",
  72. " </tr>\n",
  73. " </thead>\n",
  74. " <tbody>\n",
  75. " <tr>\n",
  76. " <th>0</th>\n",
  77. " <td>0.0312</td>\n",
  78. " <td>0.0325</td>\n",
  79. " <td>0.0303</td>\n",
  80. " <td>0.0314</td>\n",
  81. " </tr>\n",
  82. " <tr>\n",
  83. " <th>1</th>\n",
  84. " <td>0.0625</td>\n",
  85. " <td>0.0651</td>\n",
  86. " <td>0.0586</td>\n",
  87. " <td>0.0609</td>\n",
  88. " </tr>\n",
  89. " <tr>\n",
  90. " <th>2</th>\n",
  91. " <td>0.1250</td>\n",
  92. " <td>0.1304</td>\n",
  93. " <td>0.1094</td>\n",
  94. " <td>0.1134</td>\n",
  95. " </tr>\n",
  96. " <tr>\n",
  97. " <th>3</th>\n",
  98. " <td>0.2500</td>\n",
  99. " <td>0.2582</td>\n",
  100. " <td>0.1875</td>\n",
  101. " <td>0.1916</td>\n",
  102. " </tr>\n",
  103. " <tr>\n",
  104. " <th>4</th>\n",
  105. " <td>0.5000</td>\n",
  106. " <td>0.5063</td>\n",
  107. " <td>0.2500</td>\n",
  108. " <td>0.2500</td>\n",
  109. " </tr>\n",
  110. " <tr>\n",
  111. " <th>5</th>\n",
  112. " <td>0.0000</td>\n",
  113. " <td>0.0000</td>\n",
  114. " <td>0.0000</td>\n",
  115. " <td>0.0000</td>\n",
  116. " </tr>\n",
  117. " </tbody>\n",
  118. "</table>\n",
  119. "</div>"
  120. ],
  121. "text/plain": [
  122. " true_value exp_value true_var exp_var\n",
  123. "0 0.0312 0.0325 0.0303 0.0314\n",
  124. "1 0.0625 0.0651 0.0586 0.0609\n",
  125. "2 0.1250 0.1304 0.1094 0.1134\n",
  126. "3 0.2500 0.2582 0.1875 0.1916\n",
  127. "4 0.5000 0.5063 0.2500 0.2500\n",
  128. "5 0.0000 0.0000 0.0000 0.0000"
  129. ]
  130. },
  131. "metadata": {},
  132. "output_type": "display_data"
  133. }
  134. ],
  135. "source": [
  136. "# Chicken problem, solved analytically\n",
  137. "ns = 6\n",
  138. "nstates = ns\n",
  139. "I = np.identity(ns)\n",
  140. "\n",
  141. "# P[i, j]: probability of the transition s_i --> s_j\n",
  142. "P = np.diag(np.full(ns - 1, 0.5), k=1)  # superdiagonal: advance one state\n",
  143. "P[:, 0] = 0.5                           # column 0: move to state 0\n",
  144. "P[-1, 0] = 1                            # last state always moves to state 0\n",
  145. "\n",
  146. "# R[i, j]: expected reward on the transition s_i --> s_j;\n",
  147. "# the only nonzero entry is the step into the last state\n",
  148. "R = np.zeros((ns, ns))\n",
  149. "R[-2, -1] = 1.0\n",
  150. "r = (P * R).sum(axis=1)  # expected one-step reward per state\n",
  151. "\n",
  152. "# Per-state discount factors (zero for state 0) as a diagonal matrix\n",
  153. "gmvec = np.ones(ns)\n",
  154. "gmvec[0] = 0\n",
  155. "G = np.diag(gmvec)\n",
  156. "\n",
  157. "# Per-state bootstrapping parameters (all zero) as a diagonal matrix\n",
  158. "lmvec = np.zeros(ns)\n",
  159. "L = np.diag(lmvec)\n",
  160. "\n",
  161. "# Function approximation: identity feature matrix\n",
  162. "X = np.identity(ns)\n",
  163. "\n",
  164. "# Value function: pseudo-inverse of (I - P G) applied to r\n",
  165. "v_pi = pinv(I - P @ G) @ r\n",
  166. "\n",
  167. "\n",
  168. "# Recursive expected variance contribution (from Sobel):\n",
  169. "#   T[i] = sum_j P[i,j] * (R[i,j] + gmvec[j]*v_pi[j])**2 - v_pi[i]**2\n",
  170. "# G @ v_pi is a length-ns vector, so adding it to R broadcasts it\n",
  171. "# across rows and column j receives gmvec[j]*v_pi[j].\n",
  172. "T = -(v_pi ** 2) + (P * np.square(R + G @ v_pi)).sum(axis=1)\n",
  173. "\n",
  174. "# Variance (again from Sobel): same form as v_pi, but with the\n",
  175. "# discount matrix applied twice and T in place of r\n",
  176. "v_var = pinv(I - P @ G @ G) @ T\n",
  177. "\n",
  178. "# Experiment setup: trajectory length, per-state parameter lookups, features\n",
  179. "nsteps = 100000\n",
  180. "def gmfunc(x): return gmvec[x]  # discount for state index x\n",
  181. "def lmfunc(x): return lmvec[x]  # bootstrapping parameter for state index x\n",
  182. "phi = Int2Unary(nstates)\n",
  183. "\n",
  184. "# Draw nsteps (s, r, sp) items; NOTE(review): no RNG seed is set in the visible cells\n",
  185. "gen = trajectory_gen(P, R)\n",
  186. "slst = [next(gen) for _ in range(nsteps)]\n",
  187. "\n",
  188. "# MC-return for the trajectory, then the same triples with featurized states\n",
  189. "glst = compute_return(slst, gmfunc)\n",
  190. "xlst = [(phi(state), reward, phi(succ)) for state, reward, succ in slst]\n",
  191. "\n",
  192. "# Compare analytical and experimental values (columns are aligned on the state index)\n",
  193. "gdf = pd.DataFrame(glst, columns=['s', 'g', 'sp'])\n",
  194. "grouped = gdf.groupby('s')  # method form; top-level pd.groupby() is not available in current pandas\n",
  195. "true_value = pd.Series(v_pi, name='true_value')\n",
  196. "true_var = pd.Series(v_var, name='true_var')\n",
  197. "exp_value = grouped.aggregate({'g': 'mean'}).rename(columns={'g': 'exp_value'})\n",
  198. "exp_var = grouped.aggregate({'g': 'var'}).rename(columns={'g': 'exp_var'})  # sample variance (ddof=1)\n",
  199. "\n",
  200. "combined = pd.concat([true_value, exp_value, true_var, exp_var], axis=1)\n",
  201. "display(combined)"
  202. ]
  203. },
  204. {
  205. "cell_type": "code",
  206. "execution_count": 10,
  207. "metadata": {
  208. "collapsed": false
  209. },
  210. "outputs": [
  211. {
  212. "data": {
  213. "text/plain": [
  214. "(array([0]),)"
  215. ]
  216. },
  217. "execution_count": 10,
  218. "metadata": {},
  219. "output_type": "execute_result"
  220. }
  221. ],
  222. "source": []
  223. },
  224. {
  225. "cell_type": "code",
  226. "execution_count": 30,
  227. "metadata": {
  228. "collapsed": false
  229. },
  230. "outputs": [
  231. {
  232. "name": "stdout",
  233. "output_type": "stream",
  234. "text": [
  235. "Epoch: 0\n",
  236. "[ 0.0326 0.0617 0.1226 0.2289 0.5152 0. ]\n",
  237. "[ 0.0021 0.0081 0.0298 0.0966 0.2424 0. ]\n",
  238. "\n",
  239. "Epoch: 1\n",
  240. "[ 0.035 0.0663 0.1301 0.2453 0.5161 0. ]\n",
  241. "[ 0.0022 0.0084 0.0303 0.099 0.2504 0. ]\n",
  242. "\n",
  243. "Epoch: 2\n",
  244. "[ 0.0355 0.0679 0.1334 0.2534 0.5176 0. ]\n",
  245. "[ 0.0022 0.0083 0.0301 0.0989 0.2504 0. ]\n",
  246. "\n",
  247. "Epoch: 3\n",
  248. "[ 0.0349 0.0676 0.1338 0.2569 0.518 0. ]\n",
  249. "[ 0.0021 0.0081 0.0299 0.0987 0.2503 0. ]\n",
  250. "\n",
  251. "Epoch: 4\n",
  252. "[ 0.0341 0.0669 0.1333 0.2583 0.5178 0. ]\n",
  253. "[ 0.0021 0.0081 0.0297 0.0984 0.2502 0. ]\n",
  254. "\n",
  255. "Epoch: 5\n",
  256. "[ 0.0335 0.0662 0.1327 0.2589 0.5173 0. ]\n",
  257. "[ 0.0021 0.008 0.0295 0.0982 0.2501 0. ]\n",
  258. "\n",
  259. "Epoch: 6\n",
  260. "[ 0.0331 0.0657 0.1321 0.259 0.5166 0. ]\n",
  261. "[ 0.0021 0.0079 0.0294 0.098 0.2501 0. ]\n",
  262. "\n",
  263. "Epoch: 7\n",
  264. "[ 0.0328 0.0653 0.1316 0.259 0.516 0. ]\n",
  265. "[ 0.002 0.0079 0.0293 0.0979 0.25 0. ]\n",
  266. "\n",
  267. "Epoch: 8\n",
  268. "[ 0.0326 0.065 0.1312 0.2589 0.5154 0. ]\n",
  269. "[ 0.002 0.0079 0.0292 0.0978 0.25 0. ]\n",
  270. "\n",
  271. "Epoch: 9\n",
  272. "[ 0.0324 0.0648 0.1309 0.2588 0.5148 0. ]\n",
  273. "[ 0.002 0.0078 0.0291 0.0976 0.25 0. ]\n",
  274. "\n"
  275. ]
  276. }
  277. ],
  278. "source": [
  279. "# Setup\n",
  280. "num_epochs = 10\n",
  281. "\n",
  282. "# Algorithm and representation (feature count equals the number of states)\n",
  283. "num_features = nstates\n",
  284. "alg = TDVarTraces(num_features)\n",
  285. "phi = Int2Unary(num_features)\n",
  286. "\n",
  287. "# State-dependent discount and bootstrapping lookups\n",
  288. "gamma = lambda x: gmvec[x]\n",
  289. "lmbda = lambda x: lmvec[x]\n",
  290. "\n",
  291. "# Variance-update hyperparameters: constant across steps, so they are\n",
  292. "# set once here instead of being reassigned on every transition\n",
  293. "v_lm = 0.5\n",
  294. "v_lm_p = 0.5\n",
  295. "v_alpha = 0.001\n",
  296. "v_beta = 0.001\n",
  297. "v_eta = 0\n",
  298. "\n",
  299. "# Simulation: replay the recorded trajectory once per epoch\n",
  300. "for epoch in range(num_epochs):\n",
  301. "    alpha = 0.01 / (1 + epoch)  # value step size shrinks with the epoch\n",
  302. "    # `reward` (not `r`) so the expected-reward vector `r` is not overwritten\n",
  303. "    for s, reward, sp in slst:\n",
  304. "        x = phi(s)\n",
  305. "        xp = phi(sp)\n",
  306. "        # Value update parameters\n",
  307. "        gm = gamma(s)\n",
  308. "        gm_p = gamma(sp)\n",
  309. "        lm = lmbda(s)\n",
  310. "        # NOTE(review): the original also computed lm_p = lmbda(sp) but never\n",
  311. "        # passed it to alg.update -- confirm the update does not need it\n",
  312. "\n",
  313. "        # Variance update uses the same discounts as the value update\n",
  314. "        v_gm = gm\n",
  315. "        v_gm_p = gm_p\n",
  316. "\n",
  317. "        # Perform the update\n",
  318. "        alg.update(x, reward, xp, alpha, gm, gm_p, lm,\n",
  319. "                   v_gm, v_gm_p, v_lm, v_lm_p, v_alpha, v_beta, v_eta)\n",
  320. "\n",
  321. "    # Tracking: print alg.theta and alg.w after each epoch\n",
  322. "    print(\"Epoch:\", epoch)\n",
  323. "    print(alg.theta)\n",
  324. "    print(alg.w)\n",
  325. "    print()"
  326. ]
  327. },
  328. {
  329. "cell_type": "code",
  330. "execution_count": 21,
  331. "metadata": {
  332. "collapsed": false
  333. },
  334. "outputs": [
  335. {
  336. "data": {
  337. "text/html": [
  338. "<div>\n",
  339. "<table border=\"1\" class=\"dataframe\">\n",
  340. " <thead>\n",
  341. " <tr style=\"text-align: right;\">\n",
  342. " <th></th>\n",
  343. " <th>exp_value</th>\n",
  344. " <th>exp_var</th>\n",
  345. " </tr>\n",
  346. " <tr>\n",
  347. " <th>s</th>\n",
  348. " <th></th>\n",
  349. " <th></th>\n",
  350. " </tr>\n",
  351. " </thead>\n",
  352. " <tbody>\n",
  353. " <tr>\n",
  354. " <th>0</th>\n",
  355. " <td>0.0325</td>\n",
  356. " <td>0.0314</td>\n",
  357. " </tr>\n",
  358. " <tr>\n",
  359. " <th>1</th>\n",
  360. " <td>0.0651</td>\n",
  361. " <td>0.0609</td>\n",
  362. " </tr>\n",
  363. " <tr>\n",
  364. " <th>2</th>\n",
  365. " <td>0.1304</td>\n",
  366. " <td>0.1134</td>\n",
  367. " </tr>\n",
  368. " <tr>\n",
  369. " <th>3</th>\n",
  370. " <td>0.2582</td>\n",
  371. " <td>0.1916</td>\n",
  372. " </tr>\n",
  373. " <tr>\n",
  374. " <th>4</th>\n",
  375. " <td>0.5063</td>\n",
  376. " <td>0.2500</td>\n",
  377. " </tr>\n",
  378. " <tr>\n",
  379. " <th>5</th>\n",
  380. " <td>0.0000</td>\n",
  381. " <td>0.0000</td>\n",
  382. " </tr>\n",
  383. " </tbody>\n",
  384. "</table>\n",
  385. "</div>"
  386. ],
  387. "text/plain": [
  388. " exp_value exp_var\n",
  389. "s \n",
  390. "0 0.0325 0.0314\n",
  391. "1 0.0651 0.0609\n",
  392. "2 0.1304 0.1134\n",
  393. "3 0.2582 0.1916\n",
  394. "4 0.5063 0.2500\n",
  395. "5 0.0000 0.0000"
  396. ]
  397. },
  398. "execution_count": 21,
  399. "metadata": {},
  400. "output_type": "execute_result"
  401. }
  402. ],
  403. "source": [
  404. "pd.concat([exp_value, exp_var], axis='columns')  # per-state empirical mean and variance, side by side"
  405. ]
  406. },
  407. {
  408. "cell_type": "code",
  409. "execution_count": null,
  410. "metadata": {
  411. "collapsed": true
  412. },
  413. "outputs": [],
  414. "source": []
  415. }
  416. ],
  417. "metadata": {
  418. "anaconda-cloud": {},
  419. "kernelspec": {
  420. "display_name": "Python [py35]",
  421. "language": "python",
  422. "name": "Python [py35]"
  423. },
  424. "language_info": {
  425. "codemirror_mode": {
  426. "name": "ipython",
  427. "version": 3
  428. },
  429. "file_extension": ".py",
  430. "mimetype": "text/x-python",
  431. "name": "python",
  432. "nbconvert_exporter": "python",
  433. "pygments_lexer": "ipython3",
  434. "version": "3.5.2"
  435. }
  436. },
  437. "nbformat": 4,
  438. "nbformat_minor": 0
  439. }