/funsor/pyro/distribution.py

https://github.com/pyro-ppl/funsor · Python · 114 lines · 90 code · 15 blank · 9 comment · 9 complexity · 519171ebda6e669cd22f2a731e6c0eaf MD5 · raw file

  1. # Copyright Contributors to the Pyro project.
  2. # SPDX-License-Identifier: Apache-2.0
  3. from collections import OrderedDict
  4. import torch
  5. from pyro.distributions import TorchDistribution
  6. from torch.distributions import constraints
  7. from funsor.cnf import Contraction
  8. from funsor.delta import Delta
  9. from funsor.domains import Bint
  10. from funsor.interpreter import reinterpret
  11. from funsor.pyro.convert import DIM_TO_NAME, funsor_to_tensor, tensor_to_funsor
  12. from funsor.terms import Funsor, to_funsor
  13. class FunsorDistribution(TorchDistribution):
  14. """
  15. :class:`~torch.distributions.Distribution` wrapper around a
  16. :class:`~funsor.terms.Funsor` for use in Pyro code. This is typically used
  17. as a base class for specific funsor inference algorithms wrapped in a
  18. distribution interface.
  19. :param funsor.terms.Funsor funsor_dist: A funsor with an input named
  20. "value" that is treated as a random variable. The distribution should
  21. be normalized over "value".
  22. :param torch.Size batch_shape: The distribution's batch shape. This must
  23. be in the same order as the input of the ``funsor_dist``, but may
  24. contain extra dims of size 1.
  25. :param event_shape: The distribution's event shape.
  26. """
  27. arg_constraints = {}
  28. def __init__(
  29. self,
  30. funsor_dist,
  31. batch_shape=torch.Size(),
  32. event_shape=torch.Size(),
  33. dtype="real",
  34. validate_args=None,
  35. ):
  36. assert isinstance(funsor_dist, Funsor)
  37. assert isinstance(batch_shape, tuple)
  38. assert isinstance(event_shape, tuple)
  39. assert "value" in funsor_dist.inputs
  40. super(FunsorDistribution, self).__init__(
  41. batch_shape, event_shape, validate_args
  42. )
  43. self.funsor_dist = funsor_dist
  44. self.dtype = dtype
  45. @constraints.dependent_property
  46. def support(self):
  47. if self.dtype == "real":
  48. return constraints.real
  49. else:
  50. return constraints.integer_interval(0, self.dtype - 1)
  51. def log_prob(self, value):
  52. if self._validate_args:
  53. self._validate_sample(value)
  54. ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
  55. value = tensor_to_funsor(value, event_output=self.event_dim, dtype=self.dtype)
  56. log_prob = reinterpret(self.funsor_dist(value=value))
  57. log_prob = funsor_to_tensor(log_prob, ndims=ndims)
  58. return log_prob
  59. def _sample_delta(self, sample_shape):
  60. sample_inputs = None
  61. if sample_shape:
  62. sample_inputs = OrderedDict()
  63. shape = sample_shape + self.batch_shape
  64. for dim in range(-len(shape), -len(self.batch_shape)):
  65. if shape[dim] > 1:
  66. sample_inputs[DIM_TO_NAME[dim]] = Bint[shape[dim]]
  67. delta = self.funsor_dist.sample(frozenset({"value"}), sample_inputs)
  68. if isinstance(delta, Contraction):
  69. assert len([d for d in delta.terms if isinstance(d, Delta)]) == 1
  70. delta = delta.terms[0]
  71. assert isinstance(delta, Delta)
  72. return delta
  73. @torch.no_grad()
  74. def sample(self, sample_shape=torch.Size()):
  75. delta = self._sample_delta(sample_shape)
  76. ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
  77. value = funsor_to_tensor(delta.terms[0][1][0], ndims=ndims)
  78. return value.detach()
  79. def rsample(self, sample_shape=torch.Size()):
  80. delta = self._sample_delta(sample_shape)
  81. assert (
  82. not delta.log_density.requires_grad
  83. ), "distribution is not fully reparametrized"
  84. ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
  85. value = funsor_to_tensor(delta.terms[0][1][0], ndims=ndims)
  86. return value
  87. def expand(self, batch_shape, _instance=None):
  88. new = self._get_checked_instance(type(self), _instance)
  89. batch_shape = torch.Size(batch_shape)
  90. funsor_dist = self.funsor_dist + tensor_to_funsor(torch.zeros(batch_shape))
  91. super(type(self), new).__init__(
  92. funsor_dist, batch_shape, self.event_shape, self.dtype, validate_args=False
  93. )
  94. new.validate_args = self.__dict__.get("_validate_args")
  95. return new
  96. @to_funsor.register(FunsorDistribution)
  97. def funsordistribution_to_funsor(pyro_dist, output=None, dim_to_name=None):
  98. raise NotImplementedError("TODO implement conversion for FunsorDistribution")