/src/transformers/__init__.py

https://github.com/huggingface/pytorch-pretrained-BERT · Python · 741 lines · 691 code · 25 blank · 25 comment · 8 complexity · 4a1abc8e89013b93bed5664cbc8d8f1e MD5 · raw file

  1. # flake8: noqa
  2. # There's no way to ignore "F401 '...' imported but unused" warnings in this
  3. # module while preserving the other warnings, so don't check this module at all.
# Version string of this package.
__version__ = "3.3.0"
  5. # Work around to update TensorFlow's absl.logging threshold which alters the
  6. # default Python logging output behavior when present.
  7. # see: https://github.com/abseil/abseil-py/issues/99
  8. # and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
try:
    # absl is optional: its logging is only reconfigured when it can be imported.
    import absl.logging
except ImportError:
    # absl could not be imported, so there is nothing to adjust.
    pass
else:
    # Reached only when the import above succeeded.
    # Set both absl's verbosity and its stderr threshold to "info".
    absl.logging.set_verbosity("info")
    absl.logging.set_stderrthreshold("info")
    # Private absl attribute; presumably disables absl's pre-initialisation
    # stderr notice -- NOTE(review): confirm against the absl version in use.
    absl.logging._warn_preinit_stderr = False
  17. # Integrations: this needs to come before other ml imports
  18. # in order to allow any 3rd-party code to initialize properly
  19. from .integrations import ( # isort:skip
  20. is_comet_available,
  21. is_optuna_available,
  22. is_ray_available,
  23. is_tensorboard_available,
  24. is_wandb_available,
  25. )
  26. # Configurations
  27. from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
  28. from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
  29. from .configuration_bart import BartConfig
  30. from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
  31. from .configuration_bert_generation import BertGenerationConfig
  32. from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
  33. from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
  34. from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
  35. from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
  36. from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
  37. from .configuration_encoder_decoder import EncoderDecoderConfig
  38. from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
  39. from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
  40. from .configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
  41. from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
  42. from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
  43. from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
  44. from .configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
  45. from .configuration_marian import MarianConfig
  46. from .configuration_mbart import MBartConfig
  47. from .configuration_mmbt import MMBTConfig
  48. from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
  49. from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
  50. from .configuration_pegasus import PegasusConfig
  51. from .configuration_rag import RagConfig
  52. from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
  53. from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
  54. from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
  55. from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
  56. from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
  57. from .configuration_utils import PretrainedConfig
  58. from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
  59. from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
  60. from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
  61. from .data import (
  62. DataProcessor,
  63. InputExample,
  64. InputFeatures,
  65. SingleSentenceClassificationProcessor,
  66. SquadExample,
  67. SquadFeatures,
  68. SquadV1Processor,
  69. SquadV2Processor,
  70. glue_convert_examples_to_features,
  71. glue_output_modes,
  72. glue_processors,
  73. glue_tasks_num_labels,
  74. is_sklearn_available,
  75. squad_convert_examples_to_features,
  76. xnli_output_modes,
  77. xnli_processors,
  78. xnli_tasks_num_labels,
  79. )
  80. # Files and general utilities
  81. from .file_utils import (
  82. CONFIG_NAME,
  83. MODEL_CARD_NAME,
  84. PYTORCH_PRETRAINED_BERT_CACHE,
  85. PYTORCH_TRANSFORMERS_CACHE,
  86. TF2_WEIGHTS_NAME,
  87. TF_WEIGHTS_NAME,
  88. TRANSFORMERS_CACHE,
  89. WEIGHTS_NAME,
  90. add_end_docstrings,
  91. add_start_docstrings,
  92. cached_path,
  93. is_apex_available,
  94. is_datasets_available,
  95. is_faiss_available,
  96. is_psutil_available,
  97. is_py3nvml_available,
  98. is_tf_available,
  99. is_torch_available,
  100. is_torch_tpu_available,
  101. )
  102. from .hf_argparser import HfArgumentParser
  103. # Model Cards
  104. from .modelcard import ModelCard
  105. # TF 2.0 <=> PyTorch conversion utilities
  106. from .modeling_tf_pytorch_utils import (
  107. convert_tf_weight_name_to_pt_weight_name,
  108. load_pytorch_checkpoint_in_tf2_model,
  109. load_pytorch_model_in_tf2_model,
  110. load_pytorch_weights_in_tf2_model,
  111. load_tf2_checkpoint_in_pytorch_model,
  112. load_tf2_model_in_pytorch_model,
  113. load_tf2_weights_in_pytorch_model,
  114. )
  115. # Pipelines
  116. from .pipelines import (
  117. Conversation,
  118. ConversationalPipeline,
  119. CsvPipelineDataFormat,
  120. FeatureExtractionPipeline,
  121. FillMaskPipeline,
  122. JsonPipelineDataFormat,
  123. NerPipeline,
  124. PipedPipelineDataFormat,
  125. Pipeline,
  126. PipelineDataFormat,
  127. QuestionAnsweringPipeline,
  128. SummarizationPipeline,
  129. Text2TextGenerationPipeline,
  130. TextClassificationPipeline,
  131. TextGenerationPipeline,
  132. TokenClassificationPipeline,
  133. TranslationPipeline,
  134. ZeroShotClassificationPipeline,
  135. pipeline,
  136. )
  137. # Retriever
  138. from .retrieval_rag import RagRetriever
  139. # Tokenizers
  140. from .tokenization_albert import AlbertTokenizer
  141. from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
  142. from .tokenization_bart import BartTokenizer, BartTokenizerFast
  143. from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
  144. from .tokenization_bert_generation import BertGenerationTokenizer
  145. from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
  146. from .tokenization_bertweet import BertweetTokenizer
  147. from .tokenization_camembert import CamembertTokenizer
  148. from .tokenization_ctrl import CTRLTokenizer
  149. from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
  150. from .tokenization_dpr import (
  151. DPRContextEncoderTokenizer,
  152. DPRContextEncoderTokenizerFast,
  153. DPRQuestionEncoderTokenizer,
  154. DPRQuestionEncoderTokenizerFast,
  155. DPRReaderTokenizer,
  156. DPRReaderTokenizerFast,
  157. )
  158. from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
  159. from .tokenization_flaubert import FlaubertTokenizer
  160. from .tokenization_fsmt import FSMTTokenizer
  161. from .tokenization_funnel import FunnelTokenizer, FunnelTokenizerFast
  162. from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
  163. from .tokenization_layoutlm import LayoutLMTokenizer, LayoutLMTokenizerFast
  164. from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
  165. from .tokenization_lxmert import LxmertTokenizer, LxmertTokenizerFast
  166. from .tokenization_mbart import MBartTokenizer
  167. from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
  168. from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
  169. from .tokenization_pegasus import PegasusTokenizer
  170. from .tokenization_phobert import PhobertTokenizer
  171. from .tokenization_rag import RagTokenizer
  172. from .tokenization_reformer import ReformerTokenizer
  173. from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
  174. from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
  175. from .tokenization_t5 import T5Tokenizer
  176. from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
  177. from .tokenization_utils import PreTrainedTokenizer
  178. from .tokenization_utils_base import (
  179. BatchEncoding,
  180. CharSpan,
  181. PreTrainedTokenizerBase,
  182. SpecialTokensMixin,
  183. TensorType,
  184. TokenSpan,
  185. )
  186. from .tokenization_utils_fast import PreTrainedTokenizerFast
  187. from .tokenization_xlm import XLMTokenizer
  188. from .tokenization_xlm_roberta import XLMRobertaTokenizer
  189. from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
  190. # Trainer
  191. from .trainer_utils import EvalPrediction, set_seed
  192. from .training_args import TrainingArguments
  193. from .training_args_tf import TFTrainingArguments
  194. from .utils import logging
# Module-level logger, obtained from the package's own `utils.logging` helper.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# The GLUE/XNLI metric helpers are imported only when is_sklearn_available()
# returns a truthy value; otherwise these two names are not defined here.
if is_sklearn_available():
    from .data import glue_compute_metrics, xnli_compute_metrics
  198. # Modeling
  199. if is_torch_available():
  200. # Benchmarks
  201. from .benchmark.benchmark import PyTorchBenchmark
  202. from .benchmark.benchmark_args import PyTorchBenchmarkArguments
  203. from .data.data_collator import (
  204. DataCollator,
  205. DataCollatorForLanguageModeling,
  206. DataCollatorForNextSentencePrediction,
  207. DataCollatorForPermutationLanguageModeling,
  208. DataCollatorForSOP,
  209. DataCollatorWithPadding,
  210. default_data_collator,
  211. )
  212. from .data.datasets import (
  213. GlueDataset,
  214. GlueDataTrainingArguments,
  215. LineByLineTextDataset,
  216. LineByLineWithSOPTextDataset,
  217. SquadDataset,
  218. SquadDataTrainingArguments,
  219. TextDataset,
  220. TextDatasetForNextSentencePrediction,
  221. )
  222. from .generation_utils import top_k_top_p_filtering
  223. from .modeling_albert import (
  224. ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  225. AlbertForMaskedLM,
  226. AlbertForMultipleChoice,
  227. AlbertForPreTraining,
  228. AlbertForQuestionAnswering,
  229. AlbertForSequenceClassification,
  230. AlbertForTokenClassification,
  231. AlbertModel,
  232. AlbertPreTrainedModel,
  233. load_tf_weights_in_albert,
  234. )
  235. from .modeling_auto import (
  236. MODEL_FOR_CAUSAL_LM_MAPPING,
  237. MODEL_FOR_MASKED_LM_MAPPING,
  238. MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
  239. MODEL_FOR_PRETRAINING_MAPPING,
  240. MODEL_FOR_QUESTION_ANSWERING_MAPPING,
  241. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
  242. MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
  243. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
  244. MODEL_MAPPING,
  245. MODEL_WITH_LM_HEAD_MAPPING,
  246. AutoModel,
  247. AutoModelForCausalLM,
  248. AutoModelForMaskedLM,
  249. AutoModelForMultipleChoice,
  250. AutoModelForPreTraining,
  251. AutoModelForQuestionAnswering,
  252. AutoModelForSeq2SeqLM,
  253. AutoModelForSequenceClassification,
  254. AutoModelForTokenClassification,
  255. AutoModelWithLMHead,
  256. )
  257. from .modeling_bart import (
  258. BART_PRETRAINED_MODEL_ARCHIVE_LIST,
  259. BartForConditionalGeneration,
  260. BartForQuestionAnswering,
  261. BartForSequenceClassification,
  262. BartModel,
  263. PretrainedBartModel,
  264. )
  265. from .modeling_bert import (
  266. BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  267. BertForMaskedLM,
  268. BertForMultipleChoice,
  269. BertForNextSentencePrediction,
  270. BertForPreTraining,
  271. BertForQuestionAnswering,
  272. BertForSequenceClassification,
  273. BertForTokenClassification,
  274. BertLayer,
  275. BertLMHeadModel,
  276. BertModel,
  277. BertPreTrainedModel,
  278. load_tf_weights_in_bert,
  279. )
  280. from .modeling_bert_generation import (
  281. BertGenerationDecoder,
  282. BertGenerationEncoder,
  283. load_tf_weights_in_bert_generation,
  284. )
  285. from .modeling_camembert import (
  286. CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  287. CamembertForCausalLM,
  288. CamembertForMaskedLM,
  289. CamembertForMultipleChoice,
  290. CamembertForQuestionAnswering,
  291. CamembertForSequenceClassification,
  292. CamembertForTokenClassification,
  293. CamembertModel,
  294. )
  295. from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
  296. from .modeling_distilbert import (
  297. DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  298. DistilBertForMaskedLM,
  299. DistilBertForMultipleChoice,
  300. DistilBertForQuestionAnswering,
  301. DistilBertForSequenceClassification,
  302. DistilBertForTokenClassification,
  303. DistilBertModel,
  304. DistilBertPreTrainedModel,
  305. )
  306. from .modeling_dpr import (
  307. DPRContextEncoder,
  308. DPRPretrainedContextEncoder,
  309. DPRPretrainedQuestionEncoder,
  310. DPRPretrainedReader,
  311. DPRQuestionEncoder,
  312. DPRReader,
  313. )
  314. from .modeling_electra import (
  315. ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
  316. ElectraForMaskedLM,
  317. ElectraForMultipleChoice,
  318. ElectraForPreTraining,
  319. ElectraForQuestionAnswering,
  320. ElectraForSequenceClassification,
  321. ElectraForTokenClassification,
  322. ElectraModel,
  323. ElectraPreTrainedModel,
  324. load_tf_weights_in_electra,
  325. )
  326. from .modeling_encoder_decoder import EncoderDecoderModel
  327. from .modeling_flaubert import (
  328. FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  329. FlaubertForMultipleChoice,
  330. FlaubertForQuestionAnswering,
  331. FlaubertForQuestionAnsweringSimple,
  332. FlaubertForSequenceClassification,
  333. FlaubertForTokenClassification,
  334. FlaubertModel,
  335. FlaubertWithLMHeadModel,
  336. )
  337. from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel
  338. from .modeling_funnel import (
  339. FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST,
  340. FunnelBaseModel,
  341. FunnelForMaskedLM,
  342. FunnelForMultipleChoice,
  343. FunnelForPreTraining,
  344. FunnelForQuestionAnswering,
  345. FunnelForSequenceClassification,
  346. FunnelForTokenClassification,
  347. FunnelModel,
  348. load_tf_weights_in_funnel,
  349. )
  350. from .modeling_gpt2 import (
  351. GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
  352. GPT2DoubleHeadsModel,
  353. GPT2LMHeadModel,
  354. GPT2Model,
  355. GPT2PreTrainedModel,
  356. load_tf_weights_in_gpt2,
  357. )
  358. from .modeling_layoutlm import (
  359. LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
  360. LayoutLMForMaskedLM,
  361. LayoutLMForTokenClassification,
  362. LayoutLMModel,
  363. )
  364. from .modeling_longformer import (
  365. LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
  366. LongformerForMaskedLM,
  367. LongformerForMultipleChoice,
  368. LongformerForQuestionAnswering,
  369. LongformerForSequenceClassification,
  370. LongformerForTokenClassification,
  371. LongformerModel,
  372. LongformerSelfAttention,
  373. )
  374. from .modeling_lxmert import (
  375. LxmertEncoder,
  376. LxmertForPreTraining,
  377. LxmertForQuestionAnswering,
  378. LxmertModel,
  379. LxmertPreTrainedModel,
  380. LxmertVisualFeatureEncoder,
  381. LxmertXLayer,
  382. )
  383. from .modeling_marian import MarianMTModel
  384. from .modeling_mbart import MBartForConditionalGeneration
  385. from .modeling_mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
  386. from .modeling_mobilebert import (
  387. MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  388. MobileBertForMaskedLM,
  389. MobileBertForMultipleChoice,
  390. MobileBertForNextSentencePrediction,
  391. MobileBertForPreTraining,
  392. MobileBertForQuestionAnswering,
  393. MobileBertForSequenceClassification,
  394. MobileBertForTokenClassification,
  395. MobileBertLayer,
  396. MobileBertModel,
  397. MobileBertPreTrainedModel,
  398. load_tf_weights_in_mobilebert,
  399. )
  400. from .modeling_openai import (
  401. OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
  402. OpenAIGPTDoubleHeadsModel,
  403. OpenAIGPTLMHeadModel,
  404. OpenAIGPTModel,
  405. OpenAIGPTPreTrainedModel,
  406. load_tf_weights_in_openai_gpt,
  407. )
  408. from .modeling_pegasus import PegasusForConditionalGeneration
  409. from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
  410. from .modeling_reformer import (
  411. REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
  412. ReformerAttention,
  413. ReformerForMaskedLM,
  414. ReformerForQuestionAnswering,
  415. ReformerForSequenceClassification,
  416. ReformerLayer,
  417. ReformerModel,
  418. ReformerModelWithLMHead,
  419. )
  420. from .modeling_retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
  421. from .modeling_roberta import (
  422. ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
  423. RobertaForCausalLM,
  424. RobertaForMaskedLM,
  425. RobertaForMultipleChoice,
  426. RobertaForQuestionAnswering,
  427. RobertaForSequenceClassification,
  428. RobertaForTokenClassification,
  429. RobertaModel,
  430. )
  431. from .modeling_t5 import (
  432. T5_PRETRAINED_MODEL_ARCHIVE_LIST,
  433. T5ForConditionalGeneration,
  434. T5Model,
  435. T5PreTrainedModel,
  436. load_tf_weights_in_t5,
  437. )
  438. from .modeling_transfo_xl import (
  439. TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
  440. AdaptiveEmbedding,
  441. TransfoXLLMHeadModel,
  442. TransfoXLModel,
  443. TransfoXLPreTrainedModel,
  444. load_tf_weights_in_transfo_xl,
  445. )
  446. from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer
  447. from .modeling_xlm import (
  448. XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
  449. XLMForMultipleChoice,
  450. XLMForQuestionAnswering,
  451. XLMForQuestionAnsweringSimple,
  452. XLMForSequenceClassification,
  453. XLMForTokenClassification,
  454. XLMModel,
  455. XLMPreTrainedModel,
  456. XLMWithLMHeadModel,
  457. )
  458. from .modeling_xlm_roberta import (
  459. XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
  460. XLMRobertaForCausalLM,
  461. XLMRobertaForMaskedLM,
  462. XLMRobertaForMultipleChoice,
  463. XLMRobertaForQuestionAnswering,
  464. XLMRobertaForSequenceClassification,
  465. XLMRobertaForTokenClassification,
  466. XLMRobertaModel,
  467. )
  468. from .modeling_xlnet import (
  469. XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
  470. XLNetForMultipleChoice,
  471. XLNetForQuestionAnswering,
  472. XLNetForQuestionAnsweringSimple,
  473. XLNetForSequenceClassification,
  474. XLNetForTokenClassification,
  475. XLNetLMHeadModel,
  476. XLNetModel,
  477. XLNetPreTrainedModel,
  478. load_tf_weights_in_xlnet,
  479. )
  480. # Optimization
  481. from .optimization import (
  482. Adafactor,
  483. AdamW,
  484. get_constant_schedule,
  485. get_constant_schedule_with_warmup,
  486. get_cosine_schedule_with_warmup,
  487. get_cosine_with_hard_restarts_schedule_with_warmup,
  488. get_linear_schedule_with_warmup,
  489. get_polynomial_decay_schedule_with_warmup,
  490. )
  491. from .tokenization_marian import MarianTokenizer
  492. # Trainer
  493. from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
  494. # TensorFlow
  495. if is_tf_available():
  496. from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments
  497. # Benchmarks
  498. from .benchmark.benchmark_tf import TensorFlowBenchmark
  499. from .generation_tf_utils import tf_top_k_top_p_filtering
  500. from .modeling_tf_albert import (
  501. TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  502. TFAlbertForMaskedLM,
  503. TFAlbertForMultipleChoice,
  504. TFAlbertForPreTraining,
  505. TFAlbertForQuestionAnswering,
  506. TFAlbertForSequenceClassification,
  507. TFAlbertForTokenClassification,
  508. TFAlbertMainLayer,
  509. TFAlbertModel,
  510. TFAlbertPreTrainedModel,
  511. )
  512. from .modeling_tf_auto import (
  513. TF_MODEL_FOR_CAUSAL_LM_MAPPING,
  514. TF_MODEL_FOR_MASKED_LM_MAPPING,
  515. TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
  516. TF_MODEL_FOR_PRETRAINING_MAPPING,
  517. TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
  518. TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
  519. TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
  520. TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
  521. TF_MODEL_MAPPING,
  522. TF_MODEL_WITH_LM_HEAD_MAPPING,
  523. TFAutoModel,
  524. TFAutoModelForCausalLM,
  525. TFAutoModelForMaskedLM,
  526. TFAutoModelForMultipleChoice,
  527. TFAutoModelForPreTraining,
  528. TFAutoModelForQuestionAnswering,
  529. TFAutoModelForSeq2SeqLM,
  530. TFAutoModelForSequenceClassification,
  531. TFAutoModelForTokenClassification,
  532. TFAutoModelWithLMHead,
  533. )
  534. from .modeling_tf_bert import (
  535. TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  536. TFBertEmbeddings,
  537. TFBertForMaskedLM,
  538. TFBertForMultipleChoice,
  539. TFBertForNextSentencePrediction,
  540. TFBertForPreTraining,
  541. TFBertForQuestionAnswering,
  542. TFBertForSequenceClassification,
  543. TFBertForTokenClassification,
  544. TFBertLMHeadModel,
  545. TFBertMainLayer,
  546. TFBertModel,
  547. TFBertPreTrainedModel,
  548. )
  549. from .modeling_tf_camembert import (
  550. TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  551. TFCamembertForMaskedLM,
  552. TFCamembertForMultipleChoice,
  553. TFCamembertForQuestionAnswering,
  554. TFCamembertForSequenceClassification,
  555. TFCamembertForTokenClassification,
  556. TFCamembertModel,
  557. )
  558. from .modeling_tf_ctrl import (
  559. TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
  560. TFCTRLLMHeadModel,
  561. TFCTRLModel,
  562. TFCTRLPreTrainedModel,
  563. )
  564. from .modeling_tf_distilbert import (
  565. TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  566. TFDistilBertForMaskedLM,
  567. TFDistilBertForMultipleChoice,
  568. TFDistilBertForQuestionAnswering,
  569. TFDistilBertForSequenceClassification,
  570. TFDistilBertForTokenClassification,
  571. TFDistilBertMainLayer,
  572. TFDistilBertModel,
  573. TFDistilBertPreTrainedModel,
  574. )
  575. from .modeling_tf_electra import (
  576. TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
  577. TFElectraForMaskedLM,
  578. TFElectraForMultipleChoice,
  579. TFElectraForPreTraining,
  580. TFElectraForQuestionAnswering,
  581. TFElectraForSequenceClassification,
  582. TFElectraForTokenClassification,
  583. TFElectraModel,
  584. TFElectraPreTrainedModel,
  585. )
  586. from .modeling_tf_flaubert import (
  587. TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  588. TFFlaubertForMultipleChoice,
  589. TFFlaubertForQuestionAnsweringSimple,
  590. TFFlaubertForSequenceClassification,
  591. TFFlaubertForTokenClassification,
  592. TFFlaubertModel,
  593. TFFlaubertWithLMHeadModel,
  594. )
  595. from .modeling_tf_funnel import (
  596. TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST,
  597. TFFunnelBaseModel,
  598. TFFunnelForMaskedLM,
  599. TFFunnelForMultipleChoice,
  600. TFFunnelForPreTraining,
  601. TFFunnelForQuestionAnswering,
  602. TFFunnelForSequenceClassification,
  603. TFFunnelForTokenClassification,
  604. TFFunnelModel,
  605. )
  606. from .modeling_tf_gpt2 import (
  607. TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
  608. TFGPT2DoubleHeadsModel,
  609. TFGPT2LMHeadModel,
  610. TFGPT2MainLayer,
  611. TFGPT2Model,
  612. TFGPT2PreTrainedModel,
  613. )
  614. from .modeling_tf_longformer import (
  615. TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
  616. TFLongformerForMaskedLM,
  617. TFLongformerForQuestionAnswering,
  618. TFLongformerModel,
  619. TFLongformerSelfAttention,
  620. )
  621. from .modeling_tf_lxmert import (
  622. TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  623. TFLxmertForPreTraining,
  624. TFLxmertMainLayer,
  625. TFLxmertModel,
  626. TFLxmertPreTrainedModel,
  627. TFLxmertVisualFeatureEncoder,
  628. )
  629. from .modeling_tf_mobilebert import (
  630. TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
  631. TFMobileBertForMaskedLM,
  632. TFMobileBertForMultipleChoice,
  633. TFMobileBertForNextSentencePrediction,
  634. TFMobileBertForPreTraining,
  635. TFMobileBertForQuestionAnswering,
  636. TFMobileBertForSequenceClassification,
  637. TFMobileBertForTokenClassification,
  638. TFMobileBertMainLayer,
  639. TFMobileBertModel,
  640. TFMobileBertPreTrainedModel,
  641. )
  642. from .modeling_tf_openai import (
  643. TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
  644. TFOpenAIGPTDoubleHeadsModel,
  645. TFOpenAIGPTLMHeadModel,
  646. TFOpenAIGPTMainLayer,
  647. TFOpenAIGPTModel,
  648. TFOpenAIGPTPreTrainedModel,
  649. )
  650. from .modeling_tf_roberta import (
  651. TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
  652. TFRobertaForMaskedLM,
  653. TFRobertaForMultipleChoice,
  654. TFRobertaForQuestionAnswering,
  655. TFRobertaForSequenceClassification,
  656. TFRobertaForTokenClassification,
  657. TFRobertaMainLayer,
  658. TFRobertaModel,
  659. TFRobertaPreTrainedModel,
  660. )
  661. from .modeling_tf_t5 import (
  662. TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
  663. TFT5ForConditionalGeneration,
  664. TFT5Model,
  665. TFT5PreTrainedModel,
  666. )
  667. from .modeling_tf_transfo_xl import (
  668. TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
  669. TFAdaptiveEmbedding,
  670. TFTransfoXLLMHeadModel,
  671. TFTransfoXLMainLayer,
  672. TFTransfoXLModel,
  673. TFTransfoXLPreTrainedModel,
  674. )
  675. from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, shape_list
  676. from .modeling_tf_xlm import (
  677. TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
  678. TFXLMForMultipleChoice,
  679. TFXLMForQuestionAnsweringSimple,
  680. TFXLMForSequenceClassification,
  681. TFXLMForTokenClassification,
  682. TFXLMMainLayer,
  683. TFXLMModel,
  684. TFXLMPreTrainedModel,
  685. TFXLMWithLMHeadModel,
  686. )
  687. from .modeling_tf_xlm_roberta import (
  688. TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
  689. TFXLMRobertaForMaskedLM,
  690. TFXLMRobertaForMultipleChoice,
  691. TFXLMRobertaForQuestionAnswering,
  692. TFXLMRobertaForSequenceClassification,
  693. TFXLMRobertaForTokenClassification,
  694. TFXLMRobertaModel,
  695. )
  696. from .modeling_tf_xlnet import (
  697. TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
  698. TFXLNetForMultipleChoice,
  699. TFXLNetForQuestionAnsweringSimple,
  700. TFXLNetForSequenceClassification,
  701. TFXLNetForTokenClassification,
  702. TFXLNetLMHeadModel,
  703. TFXLNetMainLayer,
  704. TFXLNetModel,
  705. TFXLNetPreTrainedModel,
  706. )
  707. # Optimization
  708. from .optimization_tf import AdamWeightDecay, GradientAccumulator, WarmUp, create_optimizer
  709. # Trainer
  710. from .trainer_tf import TFTrainer
  711. if not is_tf_available() and not is_torch_available():
  712. logger.warning(
  713. "Neither PyTorch nor TensorFlow >= 2.0 have been found."
  714. "Models won't be available and only tokenizers, configuration"
  715. "and file/data utilities can be used."
  716. )