Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ async def set_task_result(

protocoldagresult_ = body_["protocoldagresult"]
compute_service_id = body_["compute_service_id"]
stdout = body_.get("stdout")
stderr = body_.get("stderr")

task_sk = ScopedKey.from_str(task_scoped_key)
validate_scopes(task_sk.scope, token)
Expand All @@ -391,6 +393,29 @@ async def set_task_result(
task=task_sk, protocoldagresultref=protocoldagresultref
)

# Store stdout and stderr logs if provided
if stdout is not None:
stdout_log_ref = s3os.push_protocoldagresult_log(
log_content=stdout,
stream="stdout",
protocoldagresult_gufekey=pdr.key,
transformation=tf_sk,
protocoldagresult_ok=pdr.ok(),
creator=compute_service_id,
)
n4js.set_protocoldagresult_log(result_sk, stdout_log_ref)

if stderr is not None:
stderr_log_ref = s3os.push_protocoldagresult_log(
log_content=stderr,
stream="stderr",
protocoldagresult_gufekey=pdr.key,
transformation=tf_sk,
protocoldagresult_ok=pdr.ok(),
creator=compute_service_id,
)
n4js.set_protocoldagresult_log(result_sk, stderr_log_ref)

# if success, set task complete, remove from all hubs
# otherwise, set as errored, leave in hubs
if protocoldagresultref.ok:
Expand Down
8 changes: 8 additions & 0 deletions alchemiscale/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,21 @@ def set_task_result(
task: ScopedKey,
protocoldagresult: ProtocolDAGResult,
compute_service_id: ComputeServiceID | None = None,
stdout: str | None = None,
stderr: str | None = None,
) -> ScopedKey:

data = dict(
protocoldagresult=compress_gufe_zstd(protocoldagresult),
compute_service_id=str(compute_service_id),
)

# Add logs if provided
if stdout is not None:
data["stdout"] = stdout
if stderr is not None:
data["stderr"] = stderr

pdr_sk = self._post_resource(f"/tasks/{task}/results", data)

return ScopedKey.from_dict(pdr_sk)
Expand Down
64 changes: 51 additions & 13 deletions alchemiscale/compute/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@
import threading
from pathlib import Path
import shutil
import sys
import io
from contextlib import redirect_stdout, redirect_stderr


class _TeeStream(io.TextIOBase):
"""A stream wrapper that writes to both a capture buffer and the original stream."""

def __init__(self, capture: io.StringIO, original: io.TextIOBase):
self._capture = capture
self._original = original

def write(self, s):
self._capture.write(s)
return self._original.write(s)

def flush(self):
self._capture.flush()
self._original.flush()

from gufe import Transformation
from gufe.protocols.protocoldag import execute_DAG, ProtocolDAG, ProtocolDAGResult
Expand Down Expand Up @@ -198,16 +217,20 @@ def task_to_protocoldag(
return protocoldag, transformation, extends_protocoldagresult

def push_result(
self, task: ScopedKey, protocoldagresult: ProtocolDAGResult
self,
task: ScopedKey,
protocoldagresult: ProtocolDAGResult,
stdout: str | None = None,
stderr: str | None = None,
) -> ScopedKey:
# TODO: this method should postprocess any paths,
# leaf nodes in DAG for blob results that should go to object store

# TODO: ship paths to object store

# finally, push ProtocolDAGResult
# finally, push ProtocolDAGResult with logs
sk: ScopedKey = self.client.set_task_result(
task, protocoldagresult, self.compute_service_id
task, protocoldagresult, self.compute_service_id, stdout=stdout, stderr=stderr
)

return sk
Expand Down Expand Up @@ -237,22 +260,35 @@ def execute(self, task: ScopedKey) -> ScopedKey:
scratch.mkdir()

self.logger.info("Executing '%s'...", protocoldag)

# Capture stdout and stderr during execution.
# Use _TeeStream for stderr so that logging output (which goes to
# stderr via the StreamHandler) is both captured and still emitted.
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
stderr_tee = _TeeStream(stderr_capture, sys.stderr)

try:
protocoldagresult = execute_DAG(
protocoldag,
shared_basedir=shared,
scratch_basedir=scratch,
keep_scratch=self.keep_scratch,
raise_error=False,
n_retries=self.settings.n_retries,
)
with redirect_stdout(stdout_capture), redirect_stderr(stderr_tee):
protocoldagresult = execute_DAG(
protocoldag,
shared_basedir=shared,
scratch_basedir=scratch,
keep_scratch=self.keep_scratch,
raise_error=False,
n_retries=self.settings.n_retries,
)
finally:
if not self.keep_shared:
shutil.rmtree(shared)

if not self.keep_scratch:
shutil.rmtree(scratch)

# Get captured output
stdout_content = stdout_capture.getvalue()
stderr_content = stderr_capture.getvalue()

if protocoldagresult.ok():
self.logger.info("'%s' -> '%s' : SUCCESS", protocoldag, protocoldagresult)
else:
Expand All @@ -265,8 +301,10 @@ def execute(self, task: ScopedKey) -> ScopedKey:
failure.exception,
)

# push the result (or failure) back to the compute API
result_sk = self.push_result(task, protocoldagresult)
# push the result (or failure) back to the compute API with captured logs
result_sk = self.push_result(
task, protocoldagresult, stdout=stdout_content, stderr=stderr_content
)
self.logger.info("Pushed result `%s'", protocoldagresult)

return result_sk
Expand Down
44 changes: 44 additions & 0 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,50 @@ def get_task_failures(
return [str(sk) for sk in n4js.get_task_failures(sk)]


@router.get("/tasks/{task_scoped_key}/logs/{stream}")
def get_task_logs(
task_scoped_key,
stream: str,
*,
n4js: Neo4jStore = Depends(get_n4js_depends),
s3os: S3ObjectStore = Depends(get_s3os_depends),
token: TokenData = Depends(get_token_data_depends),
):
"""Get log content for a Task.

Parameters
----------
task_scoped_key
The ScopedKey of the Task.
stream
Either "stdout" or "stderr".

Returns
-------
list[str]
List of log contents from all ProtocolDAGResults for this Task.
"""
if stream not in ["stdout", "stderr"]:
raise HTTPException(
status_code=http_status.HTTP_400_BAD_REQUEST,
detail="stream must be 'stdout' or 'stderr'",
)

sk = ScopedKey.from_str(task_scoped_key)
validate_scopes(sk.scope, token)

# Get all log S3 locations for this task and stream
log_locations = n4js.get_task_log_locations(sk, stream=stream)

# Retrieve the actual log content from the object store
logs = []
for location in log_locations:
log_content = s3os.pull_protocoldagresult_log(location=location)
logs.append(log_content)

return logs


### strategies


Expand Down
36 changes: 36 additions & 0 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,6 +1973,42 @@ def get_task_failures(

return pdrs

def get_task_stdout(self, task: ScopedKey) -> list[str]:
"""Get stdout logs from all `ProtocolDAGResult`s for the given `Task`.

Parameters
----------
task
The `ScopedKey` of the `Task` to retrieve stdout logs for.

Returns
-------
list[str]
List of stdout log contents from all ProtocolDAGResults for this Task.
Each element corresponds to one execution attempt.

"""
logs = self._get_resource(f"/tasks/{task}/logs/stdout")
return logs

def get_task_stderr(self, task: ScopedKey) -> list[str]:
"""Get stderr logs from all `ProtocolDAGResult`s for the given `Task`.

Parameters
----------
task
The `ScopedKey` of the `Task` to retrieve stderr logs for.

Returns
-------
list[str]
List of stderr log contents from all ProtocolDAGResults for this Task.
Each element corresponds to one execution attempt.

"""
logs = self._get_resource(f"/tasks/{task}/logs/stderr")
return logs

def add_task_restart_patterns(
self,
network_scoped_key: ScopedKey,
Expand Down
48 changes: 48 additions & 0 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,54 @@ def _from_dict(cls, d):
return super()._from_dict(d_)


class ProtocolDAGResultLog(ObjectStoreRef):
"""Reference to stdout or stderr logs from a ProtocolDAGResult execution."""

stream: str # "stdout" or "stderr"

def __init__(
self,
*,
location: str | None = None,
obj_key: GufeKey,
scope: Scope,
stream: str,
datetime_created: datetime.datetime | None = None,
creator: str | None = None,
):
self.location = location
self.obj_key = GufeKey(obj_key)
self.scope = scope
self.stream = stream
self.datetime_created = datetime_created
self.creator = creator

def _to_dict(self):
return {
"location": self.location,
"obj_key": str(self.obj_key),
"scope": str(self.scope),
"stream": self.stream,
"datetime_created": (
self.datetime_created.isoformat()
if self.datetime_created is not None
else None
),
"creator": self.creator,
}

@classmethod
def _from_dict(cls, d):
d_ = copy(d)
d_["datetime_created"] = (
datetime.datetime.fromisoformat(d["datetime_created"])
if d.get("datetime_created") is not None
else None
)

return super()._from_dict(d_)


class StrategyModeEnum(StrEnum):
full = "full"
partial = "partial"
Expand Down
Loading
Loading