Coverage for src/container_collection/docker/check_docker_job.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-09-25 18:23 +0000

1from __future__ import annotations 

2 

3from typing import TYPE_CHECKING 

4 

5from prefect.context import TaskRunContext 

6from prefect.states import Failed, State 

7 

8if TYPE_CHECKING: 

9 from docker import APIClient 

10 

RETRIES_EXCEEDED_EXIT_CODE: int = 80
"""Exit code reported once a task run's retries exceed the maximum retries."""

13 

14 

def check_docker_job(api_client: APIClient, container_id: str, max_retries: int) -> int | State:
    """
    Check for exit code of a Docker container.

    If this task is running within a Prefect flow, it will use the task run
    context to get the current run count. While the run count does not exceed
    the maximum number of retries, a job that is still running produces a
    ``Failed`` state (instead of an exception), so the task can be called with
    a retry delay to periodically check the status of jobs. Once the run count
    exceeds ``max_retries``, the container is no longer inspected and
    ``RETRIES_EXCEEDED_EXIT_CODE`` is returned.

    If this task is not running within a Prefect flow, the ``max_retries``
    parameter is ignored, and jobs that are still running raise an exception.

    Parameters
    ----------
    api_client
        Docker API client.
    container_id
        ID of container to check.
    max_retries
        Maximum number of retries.

    Returns
    -------
    :
        Exit code if the job is complete, ``RETRIES_EXCEEDED_EXIT_CODE`` if
        the task run count exceeds ``max_retries``, or a ``Failed`` state if
        the job is still running within a Prefect flow.

    Raises
    ------
    RuntimeError
        If no container matches ``container_id``, or if the job is still
        running and this task is not running within a Prefect flow.
    """

    context = TaskRunContext.get()

    # Give up once the run count exceeds the retry budget. The container is
    # not inspected on this call.
    if context is not None and context.task_run.run_count > max_retries:
        return RETRIES_EXCEEDED_EXIT_CODE

    containers = api_client.containers(all=True, filters={"id": container_id})

    # Report a missing container (e.g. already removed) explicitly rather
    # than failing with an opaque IndexError on the lookup below.
    if not containers:
        message = f"No container found with ID {container_id}."
        raise RuntimeError(message)

    status = containers[0]["State"]

    if status == "running":
        # Inside a Prefect task run, return a Failed state so the task can be
        # retried. Outside a flow there is no retry mechanism, so raise.
        if context is not None:
            return Failed()
        message = "Job is in RUNNING state and does not have exit code."
        raise RuntimeError(message)

    # NOTE(review): any status other than "running" (e.g. paused, restarting)
    # falls through to wait(); with timeout=1 docker-py presumably raises a
    # timeout error if the container has not stopped by then — confirm.
    return api_client.wait(container_id, timeout=1)["StatusCode"]