234 lines
8.9 KiB
Python
234 lines
8.9 KiB
Python
"""Manager of subshells in a kernel."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import typing as t
|
|
import uuid
|
|
from functools import partial
|
|
from threading import Lock, current_thread, main_thread
|
|
|
|
import zmq
|
|
from tornado.ioloop import IOLoop
|
|
|
|
from .socket_pair import SocketPair
|
|
from .subshell import SubshellThread
|
|
from .thread import SHELL_CHANNEL_THREAD_NAME
|
|
|
|
|
|
class SubshellManager:
    """A manager of subshells.

    Controls the lifetimes of subshell threads and their associated ZMQ sockets and
    streams. Runs mostly in the shell channel thread.

    Care needed with threadsafe access here. Subshells are only added to or removed
    from the cache in the shell channel thread, so there is only ever one such write
    at any one time. Reading of cache information can be performed by other threads,
    so all accesses are protected by a lock so that they are atomic. The lock is
    never held while a subshell thread is being stopped and joined, so a subshell
    that is waiting on the lock can always make progress and terminate.

    Reply messages are not written to the shell_socket by the subshells themselves:
    they travel over inproc socket pairs whose receive callback,
    ``_send_on_shell_channel``, is registered on the shell channel IO loop.

    .. versionadded:: 7
    """

    def __init__(
        self,
        context: zmq.Context[t.Any],
        shell_channel_io_loop: IOLoop,
        shell_socket: zmq.Socket[t.Any],
    ) -> None:
        """Initialize the subshell manager.

        Must be called from the main thread.

        Parameters
        ----------
        context
            ZMQ context used to create all of the inproc socket pairs.
        shell_channel_io_loop
            IO loop on which control requests and outgoing shell messages are
            handled.
        shell_socket
            Socket on which messages forwarded by the main shell and the subshells
            are finally sent.
        """
        assert current_thread() == main_thread()

        self._context: zmq.Context[t.Any] = context
        self._shell_channel_io_loop = shell_channel_io_loop
        self._shell_socket = shell_socket
        self._cache: dict[str, SubshellThread] = {}
        self._lock_cache = Lock()  # Sync lock across threads when accessing cache.

        # Installed later by set_on_recv_callback(); None until then so that
        # _create_subshell() can detect a missing callback up front.
        self._on_recv_callback: t.Callable[..., t.Any] | None = None

        # Inproc socket pair for communication from control thread to shell channel thread,
        # such as for create_subshell_request messages. Reply messages are returned straight away.
        self.control_to_shell_channel = SocketPair(self._context, "control")
        self.control_to_shell_channel.on_recv(
            self._shell_channel_io_loop, self._process_control_request, copy=True
        )

        # Inproc socket pair for communication from shell channel thread to main thread,
        # such as for execute_request messages.
        self._shell_channel_to_main = SocketPair(self._context, "main")

        # Inproc socket pair for communication from main thread to shell channel thread,
        # such as for execute_reply messages.
        self._main_to_shell_channel = SocketPair(self._context, "main-reverse")
        self._main_to_shell_channel.on_recv(
            self._shell_channel_io_loop, self._send_on_shell_channel
        )

    def close(self) -> None:
        """Stop all subshells and close all resources.

        Must be called from the shell channel thread.
        """
        assert current_thread().name == SHELL_CHANNEL_THREAD_NAME

        # Empty the cache while holding the lock, but stop/join the threads only
        # after releasing it (as _delete_subshell does). Joining with the lock held
        # would deadlock against a subshell thread that is blocked waiting for the
        # same lock in one of the getters.
        with self._lock_cache:
            subshell_threads = list(self._cache.values())
            self._cache.clear()

        # Most recently created subshell first, matching dict.popitem() order.
        for subshell_thread in reversed(subshell_threads):
            self._stop_subshell(subshell_thread)

        self.control_to_shell_channel.close()
        self._main_to_shell_channel.close()
        self._shell_channel_to_main.close()

    def get_shell_channel_to_subshell_pair(self, subshell_id: str | None) -> SocketPair:
        """Return the inproc socket pair used to send messages from the shell channel
        to a particular subshell or main shell.

        A ``subshell_id`` of None selects the main shell. Raises KeyError if
        ``subshell_id`` is not in the cache.
        """
        if subshell_id is None:
            return self._shell_channel_to_main
        with self._lock_cache:
            return self._cache[subshell_id].shell_channel_to_subshell

    def get_subshell_to_shell_channel_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
        """Return the socket used by a particular subshell or main shell to send
        messages to the shell channel.

        A ``subshell_id`` of None selects the main shell. Raises KeyError if
        ``subshell_id`` is not in the cache.
        """
        if subshell_id is None:
            return self._main_to_shell_channel.from_socket
        with self._lock_cache:
            return self._cache[subshell_id].subshell_to_shell_channel.from_socket

    def get_shell_channel_to_subshell_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
        """Return the socket used by the shell channel to send messages to a particular
        subshell or main shell.

        A ``subshell_id`` of None selects the main shell. Raises KeyError if
        ``subshell_id`` is not in the cache.
        """
        return self.get_shell_channel_to_subshell_pair(subshell_id).from_socket

    def get_subshell_aborting(self, subshell_id: str) -> bool:
        """Get the boolean aborting flag of the specified subshell.

        Raises KeyError if ``subshell_id`` is not in the cache.
        """
        with self._lock_cache:
            return self._cache[subshell_id].aborting

    def get_subshell_asyncio_lock(self, subshell_id: str) -> asyncio.Lock:
        """Return the asyncio lock belonging to the specified subshell.

        Raises KeyError if ``subshell_id`` is not in the cache.
        """
        with self._lock_cache:
            return self._cache[subshell_id].asyncio_lock

    def list_subshell(self) -> list[str]:
        """Return list of current subshell ids.

        Can be called by any subshell using %subshell magic.
        """
        with self._lock_cache:
            return list(self._cache)

    def set_on_recv_callback(self, on_recv_callback: t.Callable[..., t.Any]) -> None:
        """Set the callback used by the main shell and all subshells to receive
        messages sent from the shell channel thread.

        Must be called from the main thread. The callback is invoked with the
        subshell id (None for the main shell) as its first positional argument,
        followed by whatever ``SocketPair.on_recv`` passes to its callback.
        """
        assert current_thread() == main_thread()
        self._on_recv_callback = on_recv_callback
        self._shell_channel_to_main.on_recv(IOLoop.current(), partial(self._on_recv_callback, None))

    def set_subshell_aborting(self, subshell_id: str, aborting: bool) -> None:
        """Set the aborting flag of the specified subshell.

        Raises KeyError if ``subshell_id`` is not in the cache.
        """
        with self._lock_cache:
            self._cache[subshell_id].aborting = aborting

    def subshell_id_from_thread_id(self, thread_id: int) -> str | None:
        """Return subshell_id of the specified thread_id.

        Returns None if ``thread_id`` is the main thread.

        Raises RuntimeError if thread_id is not the main shell or a subshell.

        Only used by %subshell magic so does not have to be fast/cached.
        """
        with self._lock_cache:
            if thread_id == main_thread().ident:
                return None
            for subshell_id, subshell in self._cache.items():
                if subshell.ident == thread_id:
                    return subshell_id
            msg = f"Thread id {thread_id!r} does not correspond to a subshell of this kernel"
            raise RuntimeError(msg)

    def _create_subshell(self) -> str:
        """Create and start a new subshell thread, returning its new id.

        Must be called from the shell channel thread.

        Raises RuntimeError if no receive callback has been installed yet via
        ``set_on_recv_callback``.
        """
        assert current_thread().name == SHELL_CHANNEL_THREAD_NAME

        # Check before creating anything, so that a failure cannot leave a
        # never-started subshell thread behind in the cache.
        if self._on_recv_callback is None:
            msg = "set_on_recv_callback must be called before a subshell can be created"
            raise RuntimeError(msg)

        subshell_id = str(uuid.uuid4())
        subshell_thread = SubshellThread(subshell_id, self._context)

        with self._lock_cache:
            assert subshell_id not in self._cache
            self._cache[subshell_id] = subshell_thread

        # Messages from the shell channel are handled on the subshell's own IO loop.
        subshell_thread.shell_channel_to_subshell.on_recv(
            subshell_thread.io_loop,
            partial(self._on_recv_callback, subshell_id),
        )

        # Messages from the subshell are sent on from the shell channel IO loop.
        subshell_thread.subshell_to_shell_channel.on_recv(
            self._shell_channel_io_loop, self._send_on_shell_channel
        )

        subshell_thread.start()
        return subshell_id

    def _delete_subshell(self, subshell_id: str) -> None:
        """Delete subshell identified by subshell_id.

        Must be called from the shell channel thread.

        Raises key error if subshell_id not in cache.
        """
        assert current_thread().name == SHELL_CHANNEL_THREAD_NAME

        with self._lock_cache:
            subshell_thread = self._cache.pop(subshell_id)

        # Stop outside the lock so other threads are not blocked during the join.
        self._stop_subshell(subshell_thread)

    def _process_control_request(
        self,
        request: list[t.Any],
    ) -> None:
        """Process a control request message received on the control inproc
        socket and return the reply. Runs in the shell channel thread.

        The first frame of ``request`` is a JSON object whose ``"type"`` is one of
        ``"create"``, ``"delete"`` (which also needs ``"subshell_id"``) or
        ``"list"``. The reply is a JSON object with ``"status"`` of ``"ok"`` or
        ``"error"``; an error reply carries the exception text in ``"evalue"``.
        """
        assert current_thread().name == SHELL_CHANNEL_THREAD_NAME

        try:
            decoded = json.loads(request[0])
            request_type = decoded["type"]
            reply: dict[str, t.Any] = {"status": "ok"}

            if request_type == "create":
                reply["subshell_id"] = self._create_subshell()
            elif request_type == "delete":
                subshell_id = decoded["subshell_id"]
                self._delete_subshell(subshell_id)
            elif request_type == "list":
                reply["subshell_id"] = self.list_subshell()
            else:
                msg = f"Unrecognised message type {request_type!r}"
                raise RuntimeError(msg)
        except BaseException as err:
            # Every failure is reported in the reply instead of being raised, so a
            # reply is always sent below.
            # NOTE(review): BaseException also swallows KeyboardInterrupt and
            # SystemExit; presumably intentional so the control thread always
            # gets its reply - confirm.
            reply = {
                "status": "error",
                "evalue": str(err),
            }

        # Return the reply to the control thread.
        self.control_to_shell_channel.to_socket.send_json(reply)

    def _send_on_shell_channel(self, msg) -> None:
        """Send a multipart message on the shell socket.

        Receive callback for the inproc sockets carrying messages from the main
        shell and from subshells. Runs in the shell channel thread.
        """
        assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
        self._shell_socket.send_multipart(msg)

    def _stop_subshell(self, subshell_thread: SubshellThread) -> None:
        """Stop a subshell thread and wait for it to finish.

        Does nothing if the thread is not alive. Must be called from the shell
        channel thread.
        """
        assert current_thread().name == SHELL_CHANNEL_THREAD_NAME

        if subshell_thread.is_alive():
            subshell_thread.stop()
            subshell_thread.join()