class VertexAIWorker(BaseWorker):
    """Prefect worker that executes flow runs within Vertex AI Jobs."""

    type = "vertex-ai"
    job_configuration = VertexAIWorkerJobConfiguration
    job_configuration_variables = VertexAIWorkerVariables
    _description = (
        "Execute flow runs within containers on Google Vertex AI. Requires "
        "a Google Cloud Platform account."
    )
    _display_name = "Google Vertex AI"
    _documentation_url = "https://prefecthq.github.io/prefect-gcp/vertex_worker/"
    _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png"  # noqa

    async def run(
        self,
        flow_run: "FlowRun",
        configuration: VertexAIWorkerJobConfiguration,
        task_status: Optional[anyio.abc.TaskStatus] = None,
    ) -> VertexAIWorkerResult:
        """
        Executes a flow run within a Vertex AI Job and waits for the flow run
        to complete.

        Args:
            flow_run: The flow run to execute
            configuration: The configuration to use when executing the flow run.
            task_status: The task status object for the current flow run. If
                provided, the task will be marked as started with the full
                Vertex AI job name.

        Returns:
            VertexAIWorkerResult: A result object containing information about
                the final state of the flow run. `status_code` is 0 when the
                job succeeded and 1 otherwise.

        Raises:
            RuntimeError: If the finished job carries an error message that does
                not contain "CANCELED", or if watching the job times out.
        """
        logger = self.get_flow_run_logger(flow_run)

        client_options = ClientOptions(
            api_endpoint=f"{configuration.region}-aiplatform.googleapis.com"
        )

        # `job_name` generates a fresh suffix on every access, so read it once.
        job_name = configuration.job_name
        job_spec = self._build_job_spec(configuration)
        job_service_async_client = (
            configuration.credentials.get_job_service_async_client(
                client_options=client_options
            )
        )

        job_run = await self._create_and_begin_job(
            job_name,
            job_spec,
            job_service_async_client,
            configuration,
            logger,
        )

        if task_status:
            task_status.started(job_run.name)

        final_job_run = await self._watch_job_run(
            job_name=job_name,
            full_job_name=job_run.name,
            job_service_async_client=job_service_async_client,
            current_state=job_run.state,
            until_states=(
                JobState.JOB_STATE_SUCCEEDED,
                JobState.JOB_STATE_FAILED,
                JobState.JOB_STATE_CANCELLED,
                JobState.JOB_STATE_EXPIRED,
            ),
            configuration=configuration,
            logger=logger,
            timeout=int(
                datetime.timedelta(
                    hours=configuration.job_spec["maximum_run_time_hours"]
                ).total_seconds()
            ),
        )

        error_msg = final_job_run.error.message

        # Vertex will include an error message upon valid
        # flow cancellations, so we'll avoid raising an error in that case
        if error_msg and "CANCELED" not in error_msg:
            raise RuntimeError(error_msg)

        status_code = 0 if final_job_run.state == JobState.JOB_STATE_SUCCEEDED else 1

        return VertexAIWorkerResult(
            identifier=final_job_run.display_name, status_code=status_code
        )

    def _build_job_spec(
        self, configuration: VertexAIWorkerJobConfiguration
    ) -> "CustomJobSpec":
        """
        Builds a job spec by gathering details.

        The caller's `configuration.job_spec` is left untouched: all edits are
        made on a shallow copy so the spec can be built more than once from the
        same configuration.

        Args:
            configuration: The job configuration whose `job_spec` dictionary
                supplies the worker pool specs, the maximum run time and any
                remaining `CustomJobSpec` fields.

        Returns:
            The `CustomJobSpec` to submit to Vertex AI.
        """
        # Work on a copy; popping from the original would strip the worker pool
        # specs from the configuration for every later reader.
        spec_fields = dict(configuration.job_spec)

        # here, we extract the `worker_pool_specs` out of the job_spec
        worker_pool_specs = [
            WorkerPoolSpec(
                container_spec=ContainerSpec(**spec["container_spec"]),
                machine_spec=MachineSpec(**spec["machine_spec"]),
                replica_count=spec["replica_count"],
                disk_spec=DiskSpec(**spec["disk_spec"]),
            )
            for spec in spec_fields.pop("worker_pool_specs", [])
        ]

        # The template calls this key `service_account_name`, but the
        # CustomJobSpec field is `service_account`; forward it under the field
        # name so it is not discarded as an unknown field.
        service_account_name = spec_fields.pop("service_account_name", None)
        if service_account_name and not spec_fields.get("service_account"):
            spec_fields["service_account"] = service_account_name

        timeout = Duration().FromTimedelta(
            td=datetime.timedelta(hours=spec_fields["maximum_run_time_hours"])
        )
        scheduling = Scheduling(timeout=timeout)

        # construct the final job spec that we will provide to Vertex AI
        # (template-only keys such as `maximum_run_time_hours` are presumably
        # dropped by `ignore_unknown_fields` -- confirm against proto-plus)
        job_spec = CustomJobSpec(
            worker_pool_specs=worker_pool_specs,
            scheduling=scheduling,
            ignore_unknown_fields=True,
            **spec_fields,
        )
        return job_spec

    async def _create_and_begin_job(
        self,
        job_name: str,
        job_spec: "CustomJobSpec",
        job_service_async_client: "JobServiceAsyncClient",
        configuration: VertexAIWorkerJobConfiguration,
        logger: PrefectLogAdapter,
    ) -> "CustomJob":
        """
        Builds a custom job and begins running it.

        Creation is attempted up to three times, waiting one second plus a
        random 0-3 seconds between attempts.

        Args:
            job_name: Display name given to the custom job.
            job_spec: The spec describing what the job runs.
            job_service_async_client: Client used to create the job.
            configuration: Supplies the project, region and labels.
            logger: Logger for progress messages.

        Returns:
            The created custom job as returned by the job service.
        """
        # create custom job
        custom_job = CustomJob(
            display_name=job_name,
            job_spec=job_spec,
            labels=self._get_compatible_labels(configuration=configuration),
        )

        # run job
        logger.info(f"Creating job {job_name!r}")

        project = configuration.project
        resource_name = f"projects/{project}/locations/{configuration.region}"

        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(3), wait=wait_fixed(1) + wait_random(0, 3)
        ):
            with attempt:
                custom_job_run = await job_service_async_client.create_custom_job(
                    parent=resource_name,
                    custom_job=custom_job,
                )

        logger.info(
            f"Job {job_name!r} created. "
            f"The full job name is {custom_job_run.name!r}"
        )

        return custom_job_run

    async def _watch_job_run(
        self,
        job_name: str,
        full_job_name: str,  # different from job_name
        job_service_async_client: "JobServiceAsyncClient",
        current_state: "JobState",
        until_states: Tuple["JobState", ...],
        configuration: VertexAIWorkerJobConfiguration,
        logger: PrefectLogAdapter,
        timeout: Optional[int] = None,
    ) -> "CustomJob":
        """
        Polls job run to see if status changed.

        State changes reported by the Vertex AI API may sometimes be inaccurate
        immediately upon startup, but should eventually report a correct running
        and then terminal state. The minimum training duration for a custom job
        is 30 seconds, so short-lived jobs may be marked as successful some time
        after a flow run has completed.

        Args:
            job_name: Display name, used only in log messages.
            full_job_name: Resource name used to fetch the job.
            job_service_async_client: Client used to fetch the job.
            current_state: The state last observed by the caller; the first
                poll is compared against it to decide whether to log a change.
            until_states: States that end the watch.
            configuration: Supplies `job_watch_poll_interval`.
            logger: Logger for state-change messages.
            timeout: Maximum number of seconds to keep polling, or None to
                poll without a limit.

        Returns:
            The job as fetched on the poll that observed one of `until_states`.

        Raises:
            RuntimeError: If `timeout` elapses before a target state is seen.
        """
        last_state = current_state
        t0 = time.time()

        while True:
            job_run = await job_service_async_client.get_custom_job(
                name=full_job_name,
            )
            state = job_run.state
            if state != last_state:
                state_label = (
                    state.name.replace("_", " ")
                    .lower()
                    .replace("state", "state is now:")
                )
                # results in "New job state is now: succeeded"
                logger.debug(f"{job_name} has new {state_label}")
                last_state = state
            else:
                # The job was fetched but reported the same state as before.
                logger.debug(f"Job {job_name} state is unchanged.")

            # Stop as soon as a target state is observed, before the timeout
            # check and the sleep, so a finished job is never reported as timed
            # out and the caller is not delayed by one more poll interval.
            if state in until_states:
                return job_run

            elapsed_time = time.time() - t0
            if timeout is not None and elapsed_time > timeout:
                raise RuntimeError(
                    f"Timed out after {elapsed_time}s while watching job for states "
                    f"{until_states!r}"
                )
            await asyncio.sleep(configuration.job_watch_poll_interval)

    def _get_compatible_labels(
        self, configuration: VertexAIWorkerJobConfiguration
    ) -> Dict[str, str]:
        """
        Ensures labels are compatible with GCP label requirements.
        https://cloud.google.com/resource-manager/docs/creating-managing-labels

        Ex: the Prefect provided key of prefect.io/flow-name -> prefect-io_flow-name

        Args:
            configuration: Supplies the `labels` mapping to convert.

        Returns:
            A new dictionary whose keys and values have both been slugified.
        """

        def _to_gcp_label(text: str) -> str:
            # Keys and values are normalised with exactly the same rules.
            return slugify(
                text,
                lowercase=True,
                replacements=[("/", "_"), (".", "-")],
                max_length=63,
                regex_pattern=_DISALLOWED_GCP_LABEL_CHARACTERS,
            )

        return {
            _to_gcp_label(key): _to_gcp_label(val)
            for key, val in configuration.labels.items()
        }

    async def kill_infrastructure(
        self,
        infrastructure_pid: str,
        configuration: VertexAIWorkerJobConfiguration,
        grace_seconds: int = 30,
    ):
        """
        Stops a job running in Vertex AI upon flow cancellation,
        based on the provided infrastructure PID + run configuration.

        Args:
            infrastructure_pid: The Vertex AI job name to cancel.
            configuration: Supplies the region and credentials.
            grace_seconds: Requested grace period. Any value other than 30 only
                produces a warning; it is not passed on to Vertex AI.
        """
        if grace_seconds != 30:
            self._logger.warning(
                f"Kill grace period of {grace_seconds}s requested, but GCP does not "
                "support dynamic grace period configuration. See here for more info: "
                "https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs/cancel"  # noqa
            )

        client_options = ClientOptions(
            api_endpoint=f"{configuration.region}-aiplatform.googleapis.com"
        )
        job_service_async_client = (
            configuration.credentials.get_job_service_async_client(
                client_options=client_options
            )
        )
        await self._stop_job(
            client=job_service_async_client,
            vertex_job_name=infrastructure_pid,
        )

    async def _stop_job(self, client: "JobServiceAsyncClient", vertex_job_name: str):
        """
        Calls the `cancel_custom_job` method on the Vertex AI Job Service Client.

        Args:
            client: Client used to send the cancel request.
            vertex_job_name: Name of the job to cancel.

        Raises:
            InfrastructureNotFound: If the cancel call fails with a message
                containing "does not exist".
            Exception: Any other failure from the cancel call is re-raised
                unchanged.
        """
        cancel_custom_job_request = CancelCustomJobRequest(name=vertex_job_name)
        try:
            await client.cancel_custom_job(
                request=cancel_custom_job_request,
            )
        except Exception as exc:
            if "does not exist" in str(exc):
                raise InfrastructureNotFound(
                    f"Cannot stop Vertex AI job; the job name {vertex_job_name!r} "
                    "could not be found."
                ) from exc
            raise
async def kill_infrastructure(
    self,
    infrastructure_pid: str,
    configuration: VertexAIWorkerJobConfiguration,
    grace_seconds: int = 30,
):
    """
    Cancel the Vertex AI job behind a cancelled flow run.

    Args:
        infrastructure_pid: The Vertex AI job name to cancel.
        configuration: Run configuration supplying the region and credentials.
        grace_seconds: Requested grace period. Values other than 30 are only
            reported in a warning; nothing is forwarded to Vertex AI.
    """
    uses_default_grace = grace_seconds == 30
    if not uses_default_grace:
        self._logger.warning(
            f"Kill grace period of {grace_seconds}s requested, but GCP does not "
            "support dynamic grace period configuration. See here for more info: "
            "https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs/cancel"  # noqa
        )

    # The job service is addressed through a region-specific endpoint.
    regional_endpoint = f"{configuration.region}-aiplatform.googleapis.com"
    endpoint_options = ClientOptions(api_endpoint=regional_endpoint)
    vertex_client = configuration.credentials.get_job_service_async_client(
        client_options=endpoint_options
    )

    await self._stop_job(client=vertex_client, vertex_job_name=infrastructure_pid)
async def run(
    self,
    flow_run: "FlowRun",
    configuration: VertexAIWorkerJobConfiguration,
    task_status: Optional[anyio.abc.TaskStatus] = None,
) -> VertexAIWorkerResult:
    """
    Run a flow run as a Vertex AI Job and block until the job finishes.

    Args:
        flow_run: The flow run to execute.
        configuration: The configuration to use when executing the flow run.
        task_status: The task status object for the current flow run. If
            provided, it is marked as started with the full job name.

    Returns:
        VertexAIWorkerResult: Identified by the finished job's display name,
            with `status_code` 0 when the job succeeded and 1 otherwise.

    Raises:
        RuntimeError: If the finished job reports an error message that does
            not contain "CANCELED".
    """
    logger = self.get_flow_run_logger(flow_run)

    endpoint_options = ClientOptions(
        api_endpoint=f"{configuration.region}-aiplatform.googleapis.com"
    )

    display_name = configuration.job_name
    spec = self._build_job_spec(configuration)
    vertex_client = configuration.credentials.get_job_service_async_client(
        client_options=endpoint_options
    )

    started_job = await self._create_and_begin_job(
        display_name,
        spec,
        vertex_client,
        configuration,
        logger,
    )

    if task_status:
        task_status.started(started_job.name)

    # Any of these states means the job is no longer going to change.
    terminal_states = (
        JobState.JOB_STATE_SUCCEEDED,
        JobState.JOB_STATE_FAILED,
        JobState.JOB_STATE_CANCELLED,
        JobState.JOB_STATE_EXPIRED,
    )
    max_run_time = datetime.timedelta(
        hours=configuration.job_spec["maximum_run_time_hours"]
    )

    finished_job = await self._watch_job_run(
        job_name=display_name,
        full_job_name=started_job.name,
        job_service_async_client=vertex_client,
        current_state=started_job.state,
        until_states=terminal_states,
        configuration=configuration,
        logger=logger,
        timeout=int(max_run_time.total_seconds()),
    )

    # Vertex attaches an error message even to legitimate cancellations, so
    # only messages that are not about a cancellation are treated as failures.
    failure_text = finished_job.error.message
    if failure_text and "CANCELED" not in failure_text:
        raise RuntimeError(failure_text)

    if finished_job.state == JobState.JOB_STATE_SUCCEEDED:
        exit_code = 0
    else:
        exit_code = 1

    return VertexAIWorkerResult(
        identifier=finished_job.display_name, status_code=exit_code
    )
Configuration class used by the Vertex AI Worker to create a Job.
An instance of this class is passed to the Vertex AI Worker's run method
for each flow run. It contains all information necessary to execute
the flow run as a Vertex AI Job.
class VertexAIWorkerJobConfiguration(BaseJobConfiguration):
    """
    Configuration class used by the Vertex AI Worker to create a Job.

    An instance of this class is passed to the Vertex AI Worker's `run` method
    for each flow run. It contains all information necessary to execute
    the flow run as a Vertex AI Job.

    Attributes:
        region: The region where the Vertex AI Job resides.
        credentials: The GCP Credentials used to connect to Vertex AI.
        job_spec: The Vertex AI Job spec used to create the Job.
        job_watch_poll_interval: The interval between GCP API calls to check Job state.
    """

    region: str = Field(
        description="The region where the Vertex AI Job resides.",
        example="us-central1",
    )
    credentials: Optional[GcpCredentials] = Field(
        title="GCP Credentials",
        default_factory=GcpCredentials,
        description="The GCP Credentials used to initiate the "
        "Vertex AI Job. If not provided credentials will be "
        "inferred from the local environment.",
    )

    job_spec: Dict[str, Any] = Field(
        template={
            "service_account_name": "{{ service_account_name }}",
            "network": "{{ network }}",
            "reserved_ip_ranges": "{{ reserved_ip_ranges }}",
            "maximum_run_time_hours": "{{ maximum_run_time_hours }}",
            "worker_pool_specs": [
                {
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": "{{ image }}",
                        "command": "{{ command }}",
                        "args": [],
                    },
                    "machine_spec": {
                        "machine_type": "{{ machine_type }}",
                        "accelerator_type": "{{ accelerator_type }}",
                        "accelerator_count": "{{ accelerator_count }}",
                    },
                    "disk_spec": {
                        "boot_disk_type": "{{ boot_disk_type }}",
                        "boot_disk_size_gb": "{{ boot_disk_size_gb }}",
                    },
                }
            ],
        }
    )
    job_watch_poll_interval: float = Field(
        default=5.0,
        title="Poll Interval (Seconds)",
        description=(
            "The amount of time to wait between GCP API calls while monitoring the "
            "state of a Vertex AI Job."
        ),
    )

    @property
    def project(self) -> str:
        """The project taken from the attached credentials."""
        return self.credentials.project

    @property
    def job_name(self) -> str:
        """
        The configured name followed by a freshly generated hex UUID suffix.

        A new suffix is produced on every access. The name can be up to 128
        characters long and can consist of any UTF-8 characters.

        Reference:
        https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.CustomJob#google_cloud_aiplatform_CustomJob_display_name
        """  # noqa
        return f"{self.name}-{uuid4().hex}"

    def prepare_for_flow_run(
        self,
        flow_run: "FlowRun",
        deployment: Optional["DeploymentResponse"] = None,
        flow: Optional["Flow"] = None,
    ):
        """
        Run the base preparation, then adapt the job spec for Vertex AI.

        After the parent class has prepared the configuration, the environment
        variables and command are written into the first worker pool spec and
        a service account is resolved.
        """
        super().prepare_for_flow_run(flow_run, deployment, flow)

        self._inject_formatted_env_vars()
        self._inject_formatted_command()
        self._ensure_existence_of_service_account()

    def _inject_formatted_env_vars(self):
        """
        Write `self.env` into the first worker pool's container spec.

        Each variable becomes a `{"name": ..., "value": ...}` entry in the
        container spec's `env` list. Invoked by `prepare_for_flow_run()`.
        """
        pool_specs = self.job_spec["worker_pool_specs"]

        env_entries = []
        for env_name, env_value in self.env.items():
            env_entries.append({"name": env_name, "value": env_value})

        pool_specs[0]["container_spec"]["env"] = env_entries

    def _inject_formatted_command(self):
        """
        Make sure the first worker pool's container command is a list.

        A string command is split with `shlex.split`; a missing command is
        replaced by the split base flow run command. Any other value is left
        as it is.
        """
        container_spec = self.job_spec["worker_pool_specs"][0]["container_spec"]
        current_command = container_spec.get("command")

        if isinstance(current_command, str):
            container_spec["command"] = shlex.split(current_command)
        elif current_command is None:
            container_spec["command"] = shlex.split(self._base_flow_run_command())

    def _ensure_existence_of_service_account(self):
        """
        Resolve the service account and store it in the job spec.

        The `service_account_name` entry of the job spec wins when it is set;
        otherwise the service account email from the credentials is used.

        Raises:
            ValueError: If neither source provides a service account.
        """
        explicit_account = self.job_spec.get("service_account_name")
        credentials_account = self.credentials._service_account_email

        chosen_account = explicit_account or credentials_account
        if chosen_account is not None:
            self.job_spec["service_account_name"] = chosen_account
            return

        raise ValueError(
            "A service account is required for the Vertex job. "
            "A service account could not be detected in the attached credentials "
            "or in the service_account_name input. "
            "Please pass in valid GCP credentials or a valid service_account_name"
        )

    @validator("job_spec")
    def _ensure_job_spec_includes_required_attributes(cls, value: Dict[str, Any]):
        """
        Ensures that the job spec includes all required components.

        The value is diffed against the base job spec; every path the diff
        would have to add is reported as missing.

        Raises:
            ValueError: If at least one required path is missing.
        """
        diff = JsonPatch.from_diff(value, _get_base_job_spec())
        absent_paths = sorted(
            operation["path"] for operation in diff if operation["op"] == "add"
        )
        if not absent_paths:
            return value

        raise ValueError(
            "Job is missing required attributes at the following paths: "
            f"{', '.join(absent_paths)}"
        )
The name can be up to 128 characters long and can consist of any UTF-8 characters. Reference:
https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.CustomJob#google_cloud_aiplatform_CustomJob_display_name
class VertexAIWorkerVariables(BaseVariables):
    """
    Default variables for the Vertex AI worker.

    The schema for this class is used to populate the `variables` section of the
    default base job template.
    """

    region: str = Field(
        description="The region where the Vertex AI Job resides.",
        example="us-central1",
    )
    image: str = Field(
        title="Image Name",
        description=(
            "The URI of a container image in the Container or Artifact Registry, "
            # The trailing space after "access" keeps the joined sentence from
            # reading "accessto".
            "used to run your Vertex AI Job. Note that Vertex AI will need access "
            "to the project and region where the container image is stored. See "
            "https://cloud.google.com/vertex-ai/docs/training/create-custom-container"
        ),
        example="gcr.io/your-project/your-repo:latest",
    )
    credentials: Optional[GcpCredentials] = Field(
        title="GCP Credentials",
        default_factory=GcpCredentials,
        description="The GCP Credentials used to initiate the "
        "Vertex AI Job. If not provided credentials will be "
        "inferred from the local environment.",
    )
    machine_type: str = Field(
        title="Machine Type",
        description=(
            "The machine type to use for the run, which controls "
            "the available CPU and memory. "
            "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec"
        ),
        default="n1-standard-4",
    )
    accelerator_type: Optional[str] = Field(
        title="Accelerator Type",
        description=(
            "The type of accelerator to attach to the machine. "
            "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec"
        ),
        example="NVIDIA_TESLA_K80",
        default=None,
    )
    accelerator_count: Optional[int] = Field(
        title="Accelerator Count",
        description=(
            "The number of accelerators to attach to the machine. "
            "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec"
        ),
        example=1,
        default=None,
    )
    boot_disk_type: str = Field(
        title="Boot Disk Type",
        description="The type of boot disk to attach to the machine.",
        default="pd-ssd",
    )
    boot_disk_size_gb: int = Field(
        title="Boot Disk Size (GB)",
        description="The size of the boot disk to attach to the machine, in gigabytes.",
        default=100,
    )
    maximum_run_time_hours: int = Field(
        default=1,
        title="Maximum Run Time (Hours)",
        description="The maximum job running time, in hours",
    )
    network: Optional[str] = Field(
        default=None,
        title="Network",
        # The trailing space after "network" keeps the joined sentence from
        # reading "networkto".
        description="The full name of the Compute Engine network "
        "to which the Job should be peered. Private services access must "
        "already be configured for the network. If left unspecified, the job "
        "is not peered with any network. "
        "For example: projects/12345/global/networks/myVPC",
    )
    reserved_ip_ranges: Optional[List[str]] = Field(
        default=None,
        title="Reserved IP Ranges",
        description="A list of names for the reserved ip ranges under the VPC "
        "network that can be used for this job. If set, we will deploy the job "
        "within the provided ip ranges. Otherwise, the job will be deployed to "
        "any ip ranges under the provided VPC network.",
    )
    service_account_name: Optional[str] = Field(
        default=None,
        title="Service Account Name",
        description=(
            "Specifies the service account to use "
            "as the run-as account in Vertex AI. The worker submitting jobs must have "
            "act-as permission on this run-as account. If unspecified, the AI "
            "Platform Custom Code Service Agent for the CustomJob's project is "
            "used. Takes precedence over the service account found in GCP credentials, "
            "and required if a service account cannot be detected in GCP credentials."
        ),
    )
    job_watch_poll_interval: float = Field(
        default=5.0,
        title="Poll Interval (Seconds)",
        description=(
            "The amount of time to wait between GCP API calls while monitoring the "
            "state of a Vertex AI Job."
        ),
    )