From 264ae5cb62c621ceb83838f4c5a183b36b127470 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:55:38 +0800 Subject: [PATCH] Allow access to code from interpreter results (#191) improve code interpreters --- lagent/actions/ipython_interactive.py | 7 ++++++- lagent/actions/ipython_manager.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/lagent/actions/ipython_interactive.py b/lagent/actions/ipython_interactive.py index c39c8893..09dc8702 100644 --- a/lagent/actions/ipython_interactive.py +++ b/lagent/actions/ipython_interactive.py @@ -52,8 +52,13 @@ def __init__( ): super().__init__(description, parser, enable) from IPython import InteractiveShell + from traitlets.config import Config + self.timeout = timeout - self._executor = InteractiveShell() + c = Config() + c.HistoryManager.enabled = False + c.HistoryManager.hist_file = ':memory:' + self._executor = InteractiveShell(config=c) self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m') self._max_out_len = max_out_len if max_out_len >= 0 else None self._use_signals = use_signals diff --git a/lagent/actions/ipython_manager.py b/lagent/actions/ipython_manager.py index 4085f8cd..1b32e190 100644 --- a/lagent/actions/ipython_manager.py +++ b/lagent/actions/ipython_manager.py @@ -46,16 +46,19 @@ def run(self): self.out_q.put('ok') elif isinstance(msg, tuple) and len(msg) == 3: i, session_id, code = msg - res = tm(self.timeout)(self.exec)(session_id, code) + res = self.exec(session_id, code) self.out_q.put((i, session_id, res)) def exec(self, session_id, code): try: + shell = self.session_id2shell[session_id] with StringIO() as io: old_stdout = sys.stdout sys.stdout = io - self.session_id2shell[session_id].run_cell( - self.extract_code(code)) + if self.timeout is False or self.timeout < 0: + shell.run_cell(self.extract_code(code)) + else: + tm(self.timeout)(shell.run_cell)(self.extract_code(code)) sys.stdout = old_stdout output = self._highlighting.sub('', io.getvalue().strip()) output = re.sub(r'^Out\[\d+\]: ', '', output) @@ -63,10 +66,10 @@ def exec(self, session_id, code): output = output.lstrip('-').strip() if output.startswith('TimeoutError'): output = 'The code interpreter encountered a timeout error.' - return {'status': 'FAILURE', 'msg': output} - return {'status': 'SUCCESS', 'value': output} + return {'status': 'FAILURE', 'msg': output, 'code': code} + return {'status': 'SUCCESS', 'value': output, 'code': code} except Exception as e: - return {'status': 'FAILURE', 'msg': str(e)} + return {'status': 'FAILURE', 'msg': str(e), 'code': code} @staticmethod def create_shell(enable_history: bool = False, in_memory: bool = True):