/rlpark.plugin.rltoys/jvsrc/rltoys/environments/stategraph/FSGAgentState.java

https://github.com/pilarski/rlpark · Java · 271 lines · 239 code · 32 blank · 0 comment · 47 complexity · 951f53807e1fe4df0d7c9904ba73b87f MD5 · raw file

  1. package rltoys.environments.stategraph;
  2. import java.util.LinkedHashMap;
  3. import java.util.LinkedHashSet;
  4. import java.util.Map;
  5. import java.util.Set;
  6. import org.apache.commons.math.linear.Array2DRowRealMatrix;
  7. import org.apache.commons.math.linear.ArrayRealVector;
  8. import org.apache.commons.math.linear.LUDecompositionImpl;
  9. import org.apache.commons.math.linear.RealMatrix;
  10. import rltoys.algorithms.representations.acting.Policy;
  11. import rltoys.algorithms.representations.actions.Action;
  12. import rltoys.algorithms.representations.actions.StateToStateAction;
  13. import rltoys.environments.stategraph.FiniteStateGraph.StepData;
  14. import rltoys.math.vector.RealVector;
  15. import rltoys.math.vector.implementations.PVector;
  16. public class FSGAgentState implements StateToStateAction {
  17. private static final long serialVersionUID = -6312948577339609928L;
  18. public final int size;
  19. private final Map<GraphState, Integer> stateIndexes;
  20. private final FiniteStateGraph graph;
  21. private final PVector featureState;
  22. public FSGAgentState(FiniteStateGraph graph) {
  23. this.graph = graph;
  24. stateIndexes = indexStates(graph.states());
  25. size = nbNonAbsorbingState();
  26. featureState = new PVector(size);
  27. }
  28. private Map<GraphState, Integer> indexStates(GraphState[] states) {
  29. Map<GraphState, Integer> stateIndexes = new LinkedHashMap<GraphState, Integer>();
  30. int ci = 0;
  31. for (GraphState state : states) {
  32. GraphState s = state;
  33. if (!s.hasNextState())
  34. continue;
  35. stateIndexes.put(s, ci);
  36. ci++;
  37. }
  38. return stateIndexes;
  39. }
  40. public StepData step() {
  41. StepData stepData = graph.step();
  42. if (stepData.s_t != null && stepData.s_t.hasNextState())
  43. featureState.data[stateIndexes.get(stepData.s_t)] = 0;
  44. if (stepData.s_tp1 != null && stepData.s_tp1.hasNextState())
  45. featureState.data[stateIndexes.get(stepData.s_tp1)] = 1;
  46. return stepData;
  47. }
  48. public PVector currentFeatureState() {
  49. if (graph.currentState() == null)
  50. return new PVector(size);
  51. return featureState;
  52. }
  53. private RealMatrix createIdentityMatrix(int size) {
  54. RealMatrix phi = new Array2DRowRealMatrix(size, size);
  55. for (int i = 0; i < size; i++)
  56. phi.setEntry(i, i, 1.0);
  57. return phi;
  58. }
  59. public RealMatrix createPhi() {
  60. RealMatrix result = new Array2DRowRealMatrix(nbStates(), nbNonAbsorbingState());
  61. for (int i = 0; i < nbStates(); i++)
  62. result.setRow(i, getFeatureVector(states()[i]).data);
  63. return result;
  64. }
  65. private PVector getFeatureVector(GraphState graphState) {
  66. PVector result = new PVector(nbNonAbsorbingState());
  67. int ci = 0;
  68. for (int i = 0; i < nbStates(); i++) {
  69. GraphState s = states()[i];
  70. if (!s.hasNextState())
  71. continue;
  72. if (s == graphState)
  73. result.data[ci] = 1;
  74. ci++;
  75. }
  76. return result;
  77. }
  78. public double[] computeSolution(Policy policy, double gamma, double lambda) {
  79. RealMatrix phi = createPhi();
  80. RealMatrix p = createTransitionProbablityMatrix(policy);
  81. ArrayRealVector d = createStateDistribution(p);
  82. RealMatrix d_pi = createStateDistributionMatrix(d);
  83. RealMatrix p_lambda = computePLambda(p, gamma, lambda);
  84. ArrayRealVector r_bar = computeAverageReward(p);
  85. RealMatrix A = computeA(phi, d_pi, gamma, p_lambda);
  86. ArrayRealVector b = computeB(phi, d_pi, p, r_bar, gamma, lambda);
  87. RealMatrix minusAInverse = new LUDecompositionImpl(A).getSolver().getInverse().scalarMultiply(-1);
  88. return minusAInverse.operate(b).getData();
  89. }
  90. private ArrayRealVector computeB(RealMatrix phi, RealMatrix dPi, RealMatrix p, ArrayRealVector rBar, double gamma,
  91. double lambda) {
  92. RealMatrix inv = computeIdMinusGammaLambdaP(p, gamma, lambda);
  93. return (ArrayRealVector) phi.transpose().operate(dPi.operate(inv.operate(rBar)));
  94. }
  95. private RealMatrix computeA(RealMatrix phi, RealMatrix dPi, double gamma, RealMatrix pLambda) {
  96. RealMatrix id = createIdentityMatrix(phi.getRowDimension());
  97. return phi.transpose().multiply(dPi.multiply(pLambda.scalarMultiply(gamma).subtract(id).multiply(phi)));
  98. }
  99. private ArrayRealVector computeAverageReward(RealMatrix p) {
  100. ArrayRealVector result = new ArrayRealVector(p.getColumnDimension());
  101. for (int i = 0; i < nbStates(); i++) {
  102. if (!states()[i].hasNextState())
  103. continue;
  104. double sum = 0;
  105. for (int j = 0; j < nbStates(); j++)
  106. sum += p.getEntry(i, j) * states()[j].reward;
  107. result.setEntry(i, sum);
  108. }
  109. return result;
  110. }
  111. private RealMatrix computePLambda(RealMatrix p, double gamma, double lambda) {
  112. RealMatrix inv = computeIdMinusGammaLambdaP(p, gamma, lambda);
  113. return inv.multiply(p).scalarMultiply(1 - lambda);
  114. }
  115. private RealMatrix computeIdMinusGammaLambdaP(RealMatrix p, double gamma, double lambda) {
  116. RealMatrix id = createIdentityMatrix(p.getColumnDimension());
  117. return new LUDecompositionImpl(id.subtract(p.scalarMultiply(lambda * gamma))).getSolver().getInverse();
  118. }
  119. private RealMatrix createStateDistributionMatrix(ArrayRealVector d) {
  120. RealMatrix d_pi = new Array2DRowRealMatrix(nbStates(), nbStates());
  121. int ci = 0;
  122. for (int i = 0; i < nbStates(); i++) {
  123. GraphState s = states()[i];
  124. if (!s.hasNextState())
  125. continue;
  126. d_pi.setEntry(i, i, d.getEntry(ci));
  127. ci++;
  128. }
  129. return d_pi;
  130. }
  131. private ArrayRealVector createStateDistribution(RealMatrix p) {
  132. RealMatrix p_copy = p.copy();
  133. p_copy = removeColumnAndRow(p_copy, absorbingStatesSet());
  134. assert p_copy.getColumnDimension() == p_copy.getRowDimension();
  135. RealMatrix id = createIdentityMatrix(p_copy.getColumnDimension());
  136. RealMatrix inv = new LUDecompositionImpl(id.subtract(p_copy)).getSolver().getInverse();
  137. RealMatrix mu = createInitialStateDistribution();
  138. RealMatrix visits = mu.multiply(inv);
  139. double sum = 0;
  140. for (int i = 0; i < visits.getColumnDimension(); i++)
  141. sum += visits.getEntry(0, i);
  142. return (ArrayRealVector) visits.scalarMultiply(1 / sum).getRowVector(0);
  143. }
  144. private Set<Integer> absorbingStatesSet() {
  145. Set<Integer> endStates = new LinkedHashSet<Integer>();
  146. for (int i = 0; i < nbStates(); i++)
  147. if (!states()[i].hasNextState())
  148. endStates.add(i);
  149. return endStates;
  150. }
  151. private int nbNonAbsorbingState() {
  152. return stateIndexes.size();
  153. }
  154. private RealMatrix removeColumnAndRow(RealMatrix m, Set<Integer> absorbingState) {
  155. RealMatrix result = new Array2DRowRealMatrix(nbNonAbsorbingState(), nbNonAbsorbingState());
  156. int ci = 0;
  157. for (int i = 0; i < m.getRowDimension(); i++) {
  158. if (absorbingState.contains(i))
  159. continue;
  160. int cj = 0;
  161. for (int j = 0; j < m.getColumnDimension(); j++) {
  162. if (absorbingState.contains(j))
  163. continue;
  164. result.setEntry(ci, cj, m.getEntry(i, j));
  165. cj++;
  166. }
  167. ci++;
  168. }
  169. return result;
  170. }
  171. private RealMatrix createInitialStateDistribution() {
  172. double[] numbers = new double[nbNonAbsorbingState()];
  173. int ci = 0;
  174. for (int i = 0; i < nbStates(); i++) {
  175. GraphState s = states()[i];
  176. if (!s.hasNextState())
  177. continue;
  178. if (s != graph.initialState())
  179. numbers[ci] = 0.0;
  180. else
  181. numbers[ci] = 1.0;
  182. ci++;
  183. }
  184. RealMatrix result = new Array2DRowRealMatrix(1, numbers.length);
  185. for (int i = 0; i < numbers.length; i++)
  186. result.setEntry(0, i, numbers[i]);
  187. return result;
  188. }
  189. private RealMatrix createTransitionProbablityMatrix(Policy policy) {
  190. RealMatrix p = new Array2DRowRealMatrix(nbStates(), nbStates());
  191. for (int si = 0; si < nbStates(); si++) {
  192. GraphState s_t = states()[si];
  193. for (Action a : graph.actions()) {
  194. double pa = policy.pi(s_t.v(), a);
  195. GraphState s_tp1 = s_t.nextState(a);
  196. if (s_tp1 != null)
  197. p.setEntry(si, graph.indexOf(s_tp1), pa);
  198. }
  199. }
  200. for (Integer absorbingState : absorbingStatesSet())
  201. p.setEntry(absorbingState, absorbingState, 1.0);
  202. return p;
  203. }
  204. private int nbStates() {
  205. return graph.nbStates();
  206. }
  207. private GraphState[] states() {
  208. return graph.states();
  209. }
  210. public Map<GraphState, Integer> stateIndexes() {
  211. return stateIndexes;
  212. }
  213. public FiniteStateGraph graph() {
  214. return graph;
  215. }
  216. public PVector featureState(GraphState s) {
  217. PVector result = new PVector(size);
  218. if (s != null && s.hasNextState())
  219. result.data[stateIndexes.get(s)] = 1;
  220. return result;
  221. }
  222. @Override
  223. public PVector stateAction(RealVector s, Action a) {
  224. PVector sa = new PVector(nbNonAbsorbingState() * graph.actions().length);
  225. if (s == null)
  226. return sa;
  227. GraphState sg = graph.state(s);
  228. for (int ai = 0; ai < graph.actions().length; ai++)
  229. if (graph.actions()[ai] == a) {
  230. sa.setEntry(ai * nbNonAbsorbingState() + stateIndexes.get(sg), 1);
  231. return sa;
  232. }
  233. return null;
  234. }
  235. @Override
  236. public int vectorSize() {
  237. return graph.actions().length * nbNonAbsorbingState();
  238. }
  239. }