/examples/airline_demo/airline_demo/resources.py

https://github.com/dagster-io/dagster · Python · 140 lines · 116 code · 24 blank · 0 comment · 4 complexity · e15271d91255d2a801aea2547be2c282 MD5 · raw file

  1. from collections import namedtuple
  2. import sqlalchemy
  3. from dagster import Field, IntSource, StringSource, resource
  4. DbInfo = namedtuple("DbInfo", "engine url jdbc_url dialect load_table host db_name")
  5. def create_redshift_db_url(username, password, hostname, port, db_name, jdbc=True):
  6. if jdbc:
  7. db_url = (
  8. "jdbc:postgresql://{hostname}:{port}/{db_name}?"
  9. "user={username}&password={password}".format(
  10. username=username, password=password, hostname=hostname, port=port, db_name=db_name
  11. )
  12. )
  13. else:
  14. db_url = "redshift+psycopg2://{username}:{password}@{hostname}:{port}/{db_name}".format(
  15. username=username, password=password, hostname=hostname, port=port, db_name=db_name
  16. )
  17. return db_url
  18. def create_redshift_engine(db_url):
  19. return sqlalchemy.create_engine(db_url)
  20. def create_postgres_db_url(username, password, hostname, port, db_name, jdbc=True):
  21. if jdbc:
  22. db_url = (
  23. "jdbc:postgresql://{hostname}:{port}/{db_name}?"
  24. "user={username}&password={password}".format(
  25. username=username, password=password, hostname=hostname, port=port, db_name=db_name
  26. )
  27. )
  28. else:
  29. db_url = "postgresql://{username}:{password}@{hostname}:{port}/{db_name}".format(
  30. username=username, password=password, hostname=hostname, port=port, db_name=db_name
  31. )
  32. return db_url
  33. def create_postgres_engine(db_url):
  34. return sqlalchemy.create_engine(db_url)
  35. @resource(
  36. {
  37. "username": Field(StringSource),
  38. "password": Field(StringSource),
  39. "hostname": Field(StringSource),
  40. "port": Field(IntSource, is_required=False, default_value=5439),
  41. "db_name": Field(StringSource),
  42. "s3_temp_dir": Field(str),
  43. }
  44. )
  45. def redshift_db_info_resource(init_context):
  46. host = init_context.resource_config["hostname"]
  47. db_name = init_context.resource_config["db_name"]
  48. db_url_jdbc = create_redshift_db_url(
  49. username=init_context.resource_config["username"],
  50. password=init_context.resource_config["password"],
  51. hostname=host,
  52. port=init_context.resource_config["port"],
  53. db_name=db_name,
  54. )
  55. db_url = create_redshift_db_url(
  56. username=init_context.resource_config["username"],
  57. password=init_context.resource_config["password"],
  58. hostname=host,
  59. port=init_context.resource_config["port"],
  60. db_name=db_name,
  61. jdbc=False,
  62. )
  63. s3_temp_dir = init_context.resource_config["s3_temp_dir"]
  64. def _do_load(data_frame, table_name):
  65. data_frame.write.format("com.databricks.spark.redshift").option(
  66. "tempdir", s3_temp_dir
  67. ).mode("overwrite").jdbc(db_url_jdbc, table_name)
  68. return DbInfo(
  69. url=db_url,
  70. jdbc_url=db_url_jdbc,
  71. engine=create_redshift_engine(db_url),
  72. dialect="redshift",
  73. load_table=_do_load,
  74. host=host,
  75. db_name=db_name,
  76. )
  77. @resource(
  78. {
  79. "username": Field(StringSource),
  80. "password": Field(StringSource),
  81. "hostname": Field(StringSource),
  82. "port": Field(IntSource, is_required=False, default_value=5432),
  83. "db_name": Field(StringSource),
  84. }
  85. )
  86. def postgres_db_info_resource(init_context):
  87. host = init_context.resource_config["hostname"]
  88. db_name = init_context.resource_config["db_name"]
  89. db_url_jdbc = create_postgres_db_url(
  90. username=init_context.resource_config["username"],
  91. password=init_context.resource_config["password"],
  92. hostname=host,
  93. port=init_context.resource_config["port"],
  94. db_name=db_name,
  95. )
  96. db_url = create_postgres_db_url(
  97. username=init_context.resource_config["username"],
  98. password=init_context.resource_config["password"],
  99. hostname=host,
  100. port=init_context.resource_config["port"],
  101. db_name=db_name,
  102. jdbc=False,
  103. )
  104. def _do_load(data_frame, table_name):
  105. data_frame.write.option("driver", "org.postgresql.Driver").mode("overwrite").jdbc(
  106. db_url_jdbc, table_name
  107. )
  108. return DbInfo(
  109. url=db_url,
  110. jdbc_url=db_url_jdbc,
  111. engine=create_postgres_engine(db_url),
  112. dialect="postgres",
  113. load_table=_do_load,
  114. host=host,
  115. db_name=db_name,
  116. )