PageRenderTime 72ms CodeModel.GetById 43ms RepoModel.GetById 0ms app.codeStats 1ms

/ffc/uflacs/backends/ffc/access.py

https://bitbucket.org/chaffra/ffc
Python | 314 lines | 208 code | 61 blank | 45 comment | 53 complexity | 9a76ce64b439cb6613cdc7acf313ce65 MD5 | raw file
  1. # -*- coding: utf-8 -*-
  2. # Copyright (C) 2011-2016 Martin Sandve Alnæs
  3. #
  4. # This file is part of UFLACS.
  5. #
  6. # UFLACS is free software: you can redistribute it and/or modify
  7. # it under the terms of the GNU Lesser General Public License as published by
  8. # the Free Software Foundation, either version 3 of the License, or
  9. # (at your option) any later version.
  10. #
  11. # UFLACS is distributed in the hope that it will be useful,
  12. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. # GNU Lesser General Public License for more details.
  15. #
  16. # You should have received a copy of the GNU Lesser General Public License
  17. # along with UFLACS. If not, see <http://www.gnu.org/licenses/>
  18. """FFC/UFC specific variable access."""
  19. from ufl.corealg.multifunction import MultiFunction
  20. from ufl.permutation import build_component_numbering
  21. from ffc.log import error, warning
  22. from ffc.uflacs.backends.ffc.symbols import FFCBackendSymbols
  23. from ffc.uflacs.backends.ffc.common import physical_quadrature_integral_types
  24. class FFCBackendAccess(MultiFunction):
  25. """FFC specific cpp formatter class."""
  26. def __init__(self, ir, language, symbols, parameters):
  27. MultiFunction.__init__(self)
  28. # Store ir and parameters
  29. self.ir = ir
  30. self.entitytype = ir["entitytype"]
  31. self.integral_type = ir["integral_type"]
  32. self.language = language
  33. self.symbols = symbols
  34. self.parameters = parameters
  35. # === Rules for all modified terminal types ===
  36. def expr(self, e, mt, tabledata, num_points):
  37. error("Missing handler for type {0}.".format(e._ufl_class_.__name__))
  38. # === Rules for literal constants ===
  39. def zero(self, e, mt, tabledata, num_points):
  40. # We shouldn't have derivatives of constants left at this point
  41. assert not (mt.global_derivatives or mt.local_derivatives)
  42. # NB! UFL doesn't retain float/int type information for zeros...
  43. L = self.language
  44. return L.LiteralFloat(0.0)
  45. def int_value(self, e, mt, tabledata, num_points):
  46. # We shouldn't have derivatives of constants left at this point
  47. assert not (mt.global_derivatives or mt.local_derivatives)
  48. L = self.language
  49. return L.LiteralInt(int(e))
  50. def float_value(self, e, mt, tabledata, num_points):
  51. # We shouldn't have derivatives of constants left at this point
  52. assert not (mt.global_derivatives or mt.local_derivatives)
  53. L = self.language
  54. return L.LiteralFloat(float(e))
  55. def argument(self, e, mt, tabledata, num_points):
  56. L = self.language
  57. # Expecting only local derivatives and values here
  58. assert not mt.global_derivatives
  59. # assert mt.global_component is None
  60. # No need to store basis function value in its own variable, just get table value directly
  61. #uname, begin, end, ttype = tabledata
  62. uname, begin, end = tabledata
  63. table_types = self.ir["expr_irs"][num_points]["table_types"]
  64. ttype = table_types[uname]
  65. if ttype == "zeros":
  66. error("Not expecting zero arguments to get this far.")
  67. return L.LiteralFloat(0.0)
  68. elif ttype == "ones":
  69. warning("Should simplify ones arguments before getting this far.")
  70. return L.LiteralFloat(1.0)
  71. entity = self.symbols.entity(self.entitytype, mt.restriction)
  72. idof = self.symbols.argument_loop_index(mt.terminal.number())
  73. if ttype == "piecewise":
  74. iq = 0
  75. else:
  76. iq = self.symbols.quadrature_loop_index(num_points)
  77. uname = L.Symbol(uname)
  78. return uname[entity][iq][idof - begin]
  79. def coefficient(self, e, mt, tabledata, num_points):
  80. # TODO: Passing type along with tabledata would make a lot of code cleaner
  81. #uname, begin, end, ttype = tabledata
  82. uname, begin, end = tabledata
  83. table_types = self.ir["expr_irs"][num_points]["table_types"]
  84. ttype = table_types[uname]
  85. if ttype == "zeros":
  86. # FIXME: Remove at earlier stage so dependent code can also be removed
  87. warning("Not expecting zero coefficients to get this far.")
  88. L = self.language
  89. return L.LiteralFloat(0.0)
  90. elif ttype == "ones" and (end - begin) == 1:
  91. # f = 1.0 * f_i, just return direct reference to dof array at dof begin
  92. return self.symbols.coefficient_dof_access(mt.terminal, begin)
  93. else:
  94. # Return symbol, see definitions for computation
  95. return self.symbols.coefficient_value(mt) #, num_points)
  96. def quadrature_weight(self, e, mt, tabledata, num_points):
  97. weight = self.symbols.weights_array(num_points)
  98. iq = self.symbols.quadrature_loop_index(num_points)
  99. return weight[iq]
  100. def spatial_coordinate(self, e, mt, tabledata, num_points):
  101. #L = self.language
  102. if mt.global_derivatives:
  103. error("Not expecting derivatives of SpatialCoordinate.")
  104. if mt.local_derivatives:
  105. error("Not expecting derivatives of SpatialCoordinate.")
  106. if mt.averaged:
  107. error("Not expecting average of SpatialCoordinates.")
  108. if self.integral_type in physical_quadrature_integral_types:
  109. # Physical coordinates are available in given variables
  110. assert num_points is None
  111. x = self.symbols.points_array(num_points)
  112. iq = self.symbols.quadrature_loop_index(num_points)
  113. gdim, = mt.terminal.ufl_shape
  114. return x[iq * gdim + mt.flat_component]
  115. else:
  116. # Physical coordinates are computed by code generated in definitions
  117. return self.symbols.x_component(mt)
  118. def cell_coordinate(self, e, mt, tabledata, num_points):
  119. #L = self.language
  120. if mt.global_derivatives:
  121. error("Not expecting derivatives of CellCoordinate.")
  122. if mt.local_derivatives:
  123. error("Not expecting derivatives of CellCoordinate.")
  124. if mt.averaged:
  125. error("Not expecting average of CellCoordinate.")
  126. if self.integral_type == "cell" and not mt.restriction:
  127. X = self.symbols.points_array(num_points)
  128. if num_points == 1:
  129. return X[mt.flat_component]
  130. else:
  131. iq = self.symbols.quadrature_loop_index(num_points)
  132. tdim, = mt.terminal.ufl_shape
  133. return X[iq * tdim + mt.flat_component]
  134. else:
  135. # X should be computed from x or Xf symbolically instead of getting here
  136. error("Expecting reference cell coordinate to be symbolically rewritten.")
  137. def facet_coordinate(self, e, mt, tabledata, num_points):
  138. L = self.language
  139. if mt.global_derivatives:
  140. error("Not expecting derivatives of FacetCoordinate.")
  141. if mt.local_derivatives:
  142. error("Not expecting derivatives of FacetCoordinate.")
  143. if mt.averaged:
  144. error("Not expecting average of FacetCoordinate.")
  145. if mt.restriction:
  146. error("Not expecting restriction of FacetCoordinate.")
  147. if self.integral_type in ("interior_facet", "exterior_facet"):
  148. tdim, = mt.terminal.ufl_shape
  149. if tdim == 0:
  150. error("Vertices have no facet coordinates.")
  151. elif tdim == 1:
  152. # 0D vertex coordinate
  153. warning("Vertex coordinate is always 0, should get rid of this in ufl geometry lowering.")
  154. return L.LiteralFloat(0.0)
  155. Xf = self.points_array(num_points)
  156. iq = self.symbols.quadrature_loop_index(num_points)
  157. assert 0 <= mt.flat_component < (tdim-1)
  158. if tdim == 2:
  159. # 1D edge coordinate
  160. assert mt.flat_component == 0
  161. return Xf[iq]
  162. else:
  163. # The general case
  164. return Xf[iq * (tdim - 1) + mt.flat_component]
  165. else:
  166. # Xf should be computed from X or x symbolically instead of getting here
  167. error("Expecting reference facet coordinate to be symbolically rewritten.")
  168. def jacobian(self, e, mt, tabledata, num_points):
  169. L = self.language
  170. if mt.global_derivatives:
  171. error("Not expecting derivatives of Jacobian.")
  172. if mt.local_derivatives:
  173. error("Not expecting derivatives of Jacobian.")
  174. if mt.averaged:
  175. error("Not expecting average of Jacobian.")
  176. return self.symbols.J_component(mt)
  177. def reference_cell_volume(self, e, mt, tabledata, access):
  178. L = self.language
  179. cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
  180. if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
  181. return L.Symbol("{0}_reference_cell_volume".format(cellname))
  182. else:
  183. error("Unhandled cell types {0}.".format(cellname))
  184. def reference_facet_volume(self, e, mt, tabledata, access):
  185. L = self.language
  186. cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
  187. if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
  188. return L.Symbol("{0}_reference_facet_volume".format(cellname))
  189. else:
  190. error("Unhandled cell types {0}.".format(cellname))
  191. def reference_normal(self, e, mt, tabledata, access):
  192. L = self.language
  193. cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
  194. if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
  195. table = L.Symbol("{0}_reference_facet_normals".format(cellname))
  196. facet = self.symbols.entity("facet", mt.restriction)
  197. return table[facet][mt.component[0]]
  198. else:
  199. error("Unhandled cell types {0}.".format(cellname))
  200. def cell_facet_jacobian(self, e, mt, tabledata, num_points):
  201. L = self.language
  202. cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
  203. if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
  204. table = L.Symbol("{0}_reference_facet_jacobian".format(cellname))
  205. facet = self.symbols.entity("facet", mt.restriction)
  206. return table[facet][mt.component[0]][mt.component[1]]
  207. elif cellname == "interval":
  208. error("The reference facet jacobian doesn't make sense for interval cell.")
  209. else:
  210. error("Unhandled cell types {0}.".format(cellname))
  211. def cell_edge_vectors(self, e, mt, tabledata, num_points):
  212. L = self.language
  213. cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
  214. if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
  215. table = L.Symbol("{0}_reference_edge_vectors".format(cellname))
  216. return table[mt.component[0]][mt.component[1]]
  217. elif cellname == "interval":
  218. error("The reference cell edge vectors doesn't make sense for interval cell.")
  219. else:
  220. error("Unhandled cell types {0}.".format(cellname))
  221. def facet_edge_vectors(self, e, mt, tabledata, num_points):
  222. L = self.language
  223. cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
  224. if cellname in ("tetrahedron", "hexahedron"):
  225. table = L.Symbol("{0}_reference_edge_vectors".format(cellname))
  226. facet = self.symbols.entity("facet", mt.restriction)
  227. return table[facet][mt.component[0]][mt.component[1]]
  228. elif cellname in ("interval", "triangle", "quadrilateral"):
  229. error("The reference cell facet edge vectors doesn't make sense for interval or triangle cell.")
  230. else:
  231. error("Unhandled cell types {0}.".format(cellname))
  232. def cell_orientation(self, e, mt, tabledata, num_points):
  233. # Error if not in manifold case:
  234. domain = mt.terminal.ufl_domain()
  235. assert domain.geometric_dimension() > domain.topological_dimension()
  236. return self.symbols.cell_orientation_internal(mt.restriction)
  237. def facet_orientation(self, e, mt, tabledata, num_points):
  238. L = self.language
  239. cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
  240. if cellname not in ("interval", "triangle", "tetrahedron"):
  241. error("Unhandled cell types {0}.".format(cellname))
  242. table = L.Symbol("{0}_facet_orientations".format(cellname))
  243. facet = self.symbols.entity("facet", mt.restriction)
  244. return table[facet]
  245. def _expect_symbolic_lowering(self, e, mt, tabledata, num_points):
  246. error("Expecting {0} to be replaced in symbolic preprocessing.".format(type(e)))
  247. facet_normal = _expect_symbolic_lowering
  248. cell_normal = _expect_symbolic_lowering
  249. jacobian_inverse = _expect_symbolic_lowering
  250. jacobian_determinant = _expect_symbolic_lowering
  251. facet_jacobian = _expect_symbolic_lowering
  252. facet_jacobian_inverse = _expect_symbolic_lowering
  253. facet_jacobian_determinant = _expect_symbolic_lowering